diff --git a/packages/base-driver/lib/express/middleware.js b/packages/base-driver/lib/express/middleware.js index f562106e4..e9db64def 100644 --- a/packages/base-driver/lib/express/middleware.js +++ b/packages/base-driver/lib/express/middleware.js @@ -86,6 +86,39 @@ export function defaultToJSONContentType(req, res, next) { next(); } +/** + * Core function to handle WebSocket upgrade requests by matching the request path + * against registered WebSocket handlers in the webSocketsMapping. + * + * @param {import('http').IncomingMessage} req - The HTTP request + * @param {import('stream').Duplex} socket - The network socket + * @param {Buffer} head - The first packet of the upgraded stream + * @param {import('@appium/types').StringRecord} webSocketsMapping - Mapping of paths to WebSocket servers + * @returns {boolean} - Returns true if the upgrade was handled, false otherwise + */ +export function tryHandleWebSocketUpgrade(req, socket, head, webSocketsMapping) { + if (_.toLower(req.headers?.upgrade) !== 'websocket') { + return false; + } + + let currentPathname; + try { + currentPathname = new URL(req.url ?? '', 'http://localhost').pathname; + } catch { + currentPathname = req.url ?? ''; + } + for (const [pathname, wsServer] of _.toPairs(webSocketsMapping)) { + if (match(pathname)(currentPathname)) { + wsServer.handleUpgrade(req, socket, head, (ws) => { + wsServer.emit('connection', ws, req); + }); + return true; + } + } + log.info(`Did not match the websocket upgrade request at ${currentPathname} to any known route`); + return false; +} + /** * * @param {import('@appium/types').StringRecord} webSocketsMapping @@ -93,23 +126,9 @@ export function defaultToJSONContentType(req, res, next) { */ export function handleUpgrade(webSocketsMapping) { return (req, res, next) => { - if (!req.headers?.upgrade || _.toLower(req.headers.upgrade) !== 'websocket') { - return next(); + if (tryHandleWebSocketUpgrade(req, req.socket, Buffer.from(''), webSocketsMapping)) { + return; } - let currentPathname; - try { - currentPathname = new URL(req.url ?? '').pathname; - } catch { - currentPathname = req.url ?? ''; - } - for (const [pathname, wsServer] of _.toPairs(webSocketsMapping)) { - if (match(pathname)(currentPathname)) { - return wsServer.handleUpgrade(req, req.socket, Buffer.from(''), (ws) => { - wsServer.emit('connection', ws, req); - }); - } - } - log.info(`Did not match the websocket upgrade request at ${currentPathname} to any known route`); next(); }; } diff --git a/packages/base-driver/lib/express/server.js b/packages/base-driver/lib/express/server.js index 6e4b42a95..ddb5d5543 100644 --- a/packages/base-driver/lib/express/server.js +++ b/packages/base-driver/lib/express/server.js @@ -14,6 +14,7 @@ import { allowCrossDomainAsyncExecute, handleIdempotency, handleUpgrade, + tryHandleWebSocketUpgrade, catch404Handler, handleLogContext, } from './middleware'; @@ -31,6 +32,158 @@ import {fs, timing} from '@appium/support'; const KEEP_ALIVE_TIMEOUT_MS = 10 * 60 * 1000; // 10 minutes +/** + * + * @param {ServerOpts} opts + * @returns {Promise} + */ +export async function server(opts) { + const { + routeConfiguringFunction, + port, + hostname, + cliArgs = /** @type {import('@appium/types').ServerArgs} */ ({}), + allowCors = true, + basePath = DEFAULT_BASE_PATH, + extraMethodMap = {}, + serverUpdaters = [], + keepAliveTimeout = KEEP_ALIVE_TIMEOUT_MS, + requestTimeout, + } = opts; + + const app = express(); + const httpServer = await createServer(app, cliArgs); + + return await new B(async (resolve, reject) => { + // we put an async function as the promise constructor because we want some things to happen in + // serial (application of plugin updates, for example). But we still need to use a promise here + // because some elements of server start failure only happen in httpServer listeners. So the + // way we resolve it is to use an async function here but to wrap all the inner logic in + // try/catch so any errors can be passed to reject. + try { + const appiumServer = configureHttp({ + httpServer, + reject, + keepAliveTimeout, + gracefulShutdownTimeout: cliArgs.shutdownTimeout, + }); + const useLegacyUpgradeHandler = !hasShouldUpgradeCallback(httpServer); + configureServer({ + app, + addRoutes: routeConfiguringFunction, + allowCors, + basePath, + extraMethodMap, + webSocketsMapping: appiumServer.webSocketsMapping, + useLegacyUpgradeHandler, + }); + // allow extensions to update the app and http server objects + for (const updater of serverUpdaters) { + await updater(app, appiumServer, cliArgs); + } + + // once all configurations and updaters have been applied, make sure to set up a catchall + // handler so that anything unknown 404s. But do this after everything else since we don't + // want to block extensions' ability to add routes if they want. + app.all('/*all', catch404Handler); + + await startServer({ + httpServer, + hostname, + port, + keepAliveTimeout, + requestTimeout, + }); + + resolve(appiumServer); + } catch (err) { + reject(err); + } + }); +} + +/** + * Sets up some Express middleware and stuff + * @param {ConfigureServerOpts} opts + */ +export function configureServer({ + app, + addRoutes, + allowCors = true, + basePath = DEFAULT_BASE_PATH, + extraMethodMap = {}, + webSocketsMapping = {}, + useLegacyUpgradeHandler = true, +}) { + basePath = normalizeBasePath(basePath); + + app.use(endLogFormatter); + app.use(handleLogContext); + + // set up static assets + app.use(favicon(path.resolve(STATIC_DIR, 'favicon.ico'))); + // eslint-disable-next-line import/no-named-as-default-member + app.use(express.static(STATIC_DIR)); + + // crash routes, for testing + app.use(`${basePath}/produce_error`, produceError); + app.use(`${basePath}/crash`, produceCrash); + + // Only use legacy Express middleware for WebSocket upgrades if shouldUpgradeCallback is not available + // When shouldUpgradeCallback is available, upgrades are handled directly on the HTTP server + // to avoid Express middleware timeout issues with long-lived connections + if (useLegacyUpgradeHandler) { + app.use(handleUpgrade(webSocketsMapping)); + } + if (allowCors) { + app.use(allowCrossDomain); + } else { + app.use(allowCrossDomainAsyncExecute(basePath)); + } + app.use(handleIdempotency); + app.use(defaultToJSONContentType); + app.use(bodyParser.urlencoded({extended: true})); + app.use(methodOverride()); + app.use(catchAllHandler); + + // make sure appium never fails because of a file size upload limit + app.use(bodyParser.json({limit: '1gb'})); + + // set up start logging (which depends on bodyParser doing its thing) + app.use(startLogFormatter); + + addRoutes(app, {basePath, extraMethodMap}); + + // dynamic routes for testing, etc. + app.all('/welcome', welcome); + app.all('/test/guinea-pig', guineaPig); + app.all('/test/guinea-pig-scrollable', guineaPigScrollable); + app.all('/test/guinea-pig-app-banner', guineaPigAppBanner); +} + +/** + * Normalize base path string + * @param {string} basePath + * @returns {string} + */ +export function normalizeBasePath(basePath) { + if (!_.isString(basePath)) { + throw new Error(`Invalid path prefix ${basePath}`); + } + + // ensure the path prefix does not end in '/', since our method map + // starts all paths with '/' + basePath = basePath.replace(/\/$/, ''); + + // likewise, ensure the path prefix does always START with /, unless the path + // is empty meaning no base path at all + if (basePath !== '' && !basePath.startsWith('/')) { + basePath = `/${basePath}`; + } + + return basePath; +} + /** * * @param {import('express').Express} app @@ -70,127 +223,6 @@ async function createServer (app, cliArgs) { }, app); } -/** - * - * @param {ServerOpts} opts - * @returns {Promise} - */ -export async function server(opts) { - const { - routeConfiguringFunction, - port, - hostname, - cliArgs = /** @type {import('@appium/types').ServerArgs} */ ({}), - allowCors = true, - basePath = DEFAULT_BASE_PATH, - extraMethodMap = {}, - serverUpdaters = [], - keepAliveTimeout = KEEP_ALIVE_TIMEOUT_MS, - requestTimeout, - } = opts; - - const app = express(); - const httpServer = await createServer(app, cliArgs); - - return await new B(async (resolve, reject) => { - // we put an async function as the promise constructor because we want some things to happen in - // serial (application of plugin updates, for example). But we still need to use a promise here - // because some elements of server start failure only happen in httpServer listeners. So the - // way we resolve it is to use an async function here but to wrap all the inner logic in - // try/catch so any errors can be passed to reject. - try { - const appiumServer = configureHttp({ - httpServer, - reject, - keepAliveTimeout, - gracefulShutdownTimeout: cliArgs.shutdownTimeout, - }); - configureServer({ - app, - addRoutes: routeConfiguringFunction, - allowCors, - basePath, - extraMethodMap, - webSocketsMapping: appiumServer.webSocketsMapping, - }); - // allow extensions to update the app and http server objects - for (const updater of serverUpdaters) { - await updater(app, appiumServer, cliArgs); - } - - // once all configurations and updaters have been applied, make sure to set up a catchall - // handler so that anything unknown 404s. But do this after everything else since we don't - // want to block extensions' ability to add routes if they want. - app.all('/*all', catch404Handler); - - await startServer({ - httpServer, - hostname, - port, - keepAliveTimeout, - requestTimeout, - }); - - resolve(appiumServer); - } catch (err) { - reject(err); - } - }); -} - -/** - * Sets up some Express middleware and stuff - * @param {ConfigureServerOpts} opts - */ -export function configureServer({ - app, - addRoutes, - allowCors = true, - basePath = DEFAULT_BASE_PATH, - extraMethodMap = {}, - webSocketsMapping = {}, -}) { - basePath = normalizeBasePath(basePath); - - app.use(endLogFormatter); - app.use(handleLogContext); - - // set up static assets - app.use(favicon(path.resolve(STATIC_DIR, 'favicon.ico'))); - // eslint-disable-next-line import/no-named-as-default-member - app.use(express.static(STATIC_DIR)); - - // crash routes, for testing - app.use(`${basePath}/produce_error`, produceError); - app.use(`${basePath}/crash`, produceCrash); - - app.use(handleUpgrade(webSocketsMapping)); - if (allowCors) { - app.use(allowCrossDomain); - } else { - app.use(allowCrossDomainAsyncExecute(basePath)); - } - app.use(handleIdempotency); - app.use(defaultToJSONContentType); - app.use(bodyParser.urlencoded({extended: true})); - app.use(methodOverride()); - app.use(catchAllHandler); - - // make sure appium never fails because of a file size upload limit - app.use(bodyParser.json({limit: '1gb'})); - - // set up start logging (which depends on bodyParser doing its thing) - app.use(startLogFormatter); - - addRoutes(app, {basePath, extraMethodMap}); - - // dynamic routes for testing, etc. - app.all('/welcome', welcome); - app.all('/test/guinea-pig', guineaPig); - app.all('/test/guinea-pig-scrollable', guineaPigScrollable); - app.all('/test/guinea-pig-app-banner', guineaPigAppBanner); -} - /** * Monkeypatches the `http.Server` instance and returns a {@linkcode AppiumServer}. * This function _mutates_ the `httpServer` parameter. @@ -212,6 +244,20 @@ function configureHttp({httpServer, reject, keepAliveTimeout, gracefulShutdownTi return Boolean(this['_spdyState']?.secure); }; + // This avoids Express middleware timeout issues with long-lived WebSocket connections + // See: https://github.com/appium/appium/issues/20760 + // See: https://github.com/nodejs/node/pull/59824 + if (hasShouldUpgradeCallback(httpServer)) { + // shouldUpgradeCallback only returns a boolean to indicate if the upgrade should proceed + // eslint-disable-next-line dot-notation + appiumServer['shouldUpgradeCallback'] = (req) => _.toLower(req.headers?.upgrade) === 'websocket'; + appiumServer.on('upgrade', (req, socket, head) => { + if (!tryHandleWebSocketUpgrade(req, socket, head, appiumServer.webSocketsMapping)) { + socket.destroy(); + } + }); + } + // http.Server.close() only stops new connections, but we need to wait until // all connections are closed and the `close` event is emitted const originalClose = appiumServer.close.bind(appiumServer); @@ -294,29 +340,23 @@ async function startServer({ } /** - * Normalize base path string - * @param {string} basePath - * @returns {string} + * Checks if the provided server instance supports `shouldUpgradeCallback`. + * This feature was added in Node.js v22.21.0 (LTS) and v24.9.0. + * @param {import('http').Server} server - The HTTP server instance to check + * @returns {boolean} */ -export function normalizeBasePath(basePath) { - if (!_.isString(basePath)) { - throw new Error(`Invalid path prefix ${basePath}`); +function hasShouldUpgradeCallback(server) { + // Check if shouldUpgradeCallback is available on http.Server + // This is a runtime check that works regardless of TypeScript types + try { + // Use bracket notation to access property that may not exist in type definitions + // eslint-disable-next-line dot-notation + return typeof server['shouldUpgradeCallback'] !== 'undefined'; + } catch { + return false; } - - // ensure the path prefix does not end in '/', since our method map - // starts all paths with '/' - basePath = basePath.replace(/\/$/, ''); - - // likewise, ensure the path prefix does always START with /, unless the path - // is empty meaning no base path at all - if (basePath !== '' && !basePath.startsWith('/')) { - basePath = `/${basePath}`; - } - - return basePath; } - /** * Options for {@linkcode startServer}. * @typedef StartServerOpts @@ -386,4 +426,6 @@ export function normalizeBasePath(basePath) { * @property {string} [basePath] * @property {MethodMap} [extraMethodMap] * @property {import('@appium/types').StringRecord} [webSocketsMapping={}] + * @property {boolean} [useLegacyUpgradeHandler=true] - Whether to use legacy Express middleware for WebSocket upgrades. + * Set to false when using shouldUpgradeCallback on the HTTP server (Node.js >= 22.21.0 or >= 24.9.0). */ diff --git a/packages/support/test/e2e/fs.e2e.spec.js b/packages/support/test/e2e/fs.e2e.spec.js index 2b912b8f0..5f304300a 100644 --- a/packages/support/test/e2e/fs.e2e.spec.js +++ b/packages/support/test/e2e/fs.e2e.spec.js @@ -81,7 +81,7 @@ describe('fs', function () { // Mock fsPromises.rename to simulate EXDEV error const originalRename = fs.rename; - fs.rename = async (src, dst) => { + fs.rename = async () => { const err = new Error('cross-device link not permitted'); err.code = 'EXDEV'; throw err;