diff --git a/.prettierignore b/.prettierignore index fefeeb5ed..69d746b86 100644 --- a/.prettierignore +++ b/.prettierignore @@ -10,4 +10,5 @@ .prettierignore .gitignore Dockerfile -*.conf \ No newline at end of file +*.conf +.gitmodules \ No newline at end of file diff --git a/src/server/GameServer.ts b/src/server/GameServer.ts index f7a60c0e7..8667cc5b9 100644 --- a/src/server/GameServer.ts +++ b/src/server/GameServer.ts @@ -19,6 +19,7 @@ import { GameType } from "../core/game/Game"; import { archive } from "./Archive"; import { Client } from "./Client"; import { slog } from "./StructuredLog"; +import { securityMiddleware } from "./Security"; export enum GamePhase { Lobby = "LOBBY", @@ -27,11 +28,6 @@ export enum GamePhase { } export class GameServer { - private rateLimiter = new RateLimiterMemory({ - points: 50, - duration: 1, // per 1 second - }); - private outOfSyncClients = new Set(); private maxGameDuration = 3 * 60 * 60 * 1000; // 3 hours @@ -125,58 +121,55 @@ export class GameServer { this.allClients.set(client.clientID, client); - client.ws.on("message", async (message: string) => { - try { - await this.rateLimiter.consume(client.ip); - } catch (error) { - console.warn(`Rate limit exceeded for ${client.ip}`); - return; - } - try { - let clientMsg: ClientMessage = null; + client.ws.on( + "message", + securityMiddleware.wsHandler(client.ip, async (message: string) => { try { - clientMsg = ClientMessageSchema.parse(JSON.parse(message)); + let clientMsg: ClientMessage = null; + try { + clientMsg = ClientMessageSchema.parse(JSON.parse(message)); + } catch (error) { + throw Error(`error parsing schema for ${client.ip}`); + } + if (this.allClients.has(clientMsg.clientID)) { + const client = this.allClients.get(clientMsg.clientID); + if (client.persistentID != clientMsg.persistentID) { + console.warn( + `Client ID ${clientMsg.clientID} sent incorrect id ${clientMsg.persistentID}, does not match persistent id ${client.persistentID}`, + ); + return; + } + } + + // Clear out persistent id to make sure it doesn't get sent to other clients. + clientMsg.persistentID = null; + + if (clientMsg.type == "intent") { + if (clientMsg.gameID == this.id) { + this.addIntent(clientMsg.intent); + } else { + console.warn( + `${this.id}: client ${clientMsg.clientID} sent to wrong game`, + ); + } + } + if (clientMsg.type == "ping") { + this.lastPingUpdate = Date.now(); + client.lastPing = Date.now(); + } + if (clientMsg.type == "hash") { + client.hashes.set(clientMsg.tick, clientMsg.hash); + } + if (clientMsg.type == "winner") { + this.winner = clientMsg.winner; + } } catch (error) { - throw Error(`error parsing schema for ${client.ip}`); + console.log( + `error handline websocket request in game server: ${error}`, + ); } - if (this.allClients.has(clientMsg.clientID)) { - const client = this.allClients.get(clientMsg.clientID); - if (client.persistentID != clientMsg.persistentID) { - console.warn( - `Client ID ${clientMsg.clientID} sent incorrect id ${clientMsg.persistentID}, does not match persistent id ${client.persistentID}`, - ); - return; - } - } - - // Clear out persistent id to make sure it doesn't get sent to other clients. - clientMsg.persistentID = null; - - if (clientMsg.type == "intent") { - if (clientMsg.gameID == this.id) { - this.addIntent(clientMsg.intent); - } else { - console.warn( - `${this.id}: client ${clientMsg.clientID} sent to wrong game`, - ); - } - } - if (clientMsg.type == "ping") { - this.lastPingUpdate = Date.now(); - client.lastPing = Date.now(); - } - if (clientMsg.type == "hash") { - client.hashes.set(clientMsg.tick, clientMsg.hash); - } - if (clientMsg.type == "winner") { - this.winner = clientMsg.winner; - } - } catch (error) { - console.log( - `error handline websocket request in game server: ${error}`, - ); - } - }); + }), + ); client.ws.on("close", () => { console.log(`${this.id}: client ${client.clientID} disconnected`); this.activeClients = this.activeClients.filter( diff --git a/src/server/Security.ts b/src/server/Security.ts new file mode 100644 index 000000000..5f38a427d --- /dev/null +++ b/src/server/Security.ts @@ -0,0 +1,103 @@ +// src/server/middleware/securityInterface.ts +import { Request, Response, NextFunction } from "express"; +import http from "http"; +import path from "path"; +import { fileURLToPath } from "url"; + +export enum LimiterType { + Get = "get", + Post = "post", + Put = "put", + WebSocket = "websocket", +} + +export interface SecurityMiddleware { + // The wrapper for request handlers with optional rate limiting + httpHandler: ( + fn: (req: Request, res: Response, next: NextFunction) => Promise, + limiterType: LimiterType, + ) => (req: Request, res: Response, next: NextFunction) => Promise; + + // The wrapper for WebSocket message handlers with rate limiting + wsHandler: ( + req: http.IncomingMessage | string, + fn: (message: string) => Promise, + ) => (message: string) => Promise; +} + +// Function to get the appropriate security middleware implementation +async function getSecurityMiddleware(): Promise { + try { + // Get the current file's directory + const __filename = fileURLToPath(import.meta.url); + const __dirname = path.dirname(__filename); + + try { + // Use dynamic import for ES modules - without file extension + // ts-node will resolve this correctly + const module = await import( + "./security-middleware/RealSecurityMiddleware" + ); + + if (!module.RealSecurityMiddleware) { + throw new Error("RealSecurityMiddleware class not found in module"); + } + + console.log("Successfully loaded real security middleware"); + return new module.RealSecurityMiddleware(); + } catch (error) { + console.log("Failed to load real security middleware:", error); + return new NoOpSecurityMiddleware(); + } + } catch (e) { + // Fall back to no-op if real implementation isn't available + console.log("using no-op security middleware", e); + return new NoOpSecurityMiddleware(); + } +} + +export class NoOpSecurityMiddleware implements SecurityMiddleware { + // Simple pass-through with no rate limiting + httpHandler( + fn: (req: Request, res: Response, next: NextFunction) => Promise, + limiterType: LimiterType, + ) { + return async (req: Request, res: Response, next: NextFunction) => { + try { + await fn(req, res, next); + } catch (error) { + next(error); + } + }; + } + + // Corrected implementation for WebSocket handler wrapper + wsHandler( + req: http.IncomingMessage | string, + fn: (message: string) => Promise, + ) { + return async (message: string) => { + try { + await fn(message); + } catch (error) { + console.error("WebSocket handler error:", error); + } + }; + } +} + +// Initialize the security middleware with a default implementation +// We'll use the NoOpSecurityMiddleware initially and then replace it +// with the real implementation once it's loaded +export const securityMiddleware: SecurityMiddleware = + new NoOpSecurityMiddleware(); + +// Immediately try to load the real middleware +getSecurityMiddleware() + .then((middleware) => { + // Replace the methods of securityMiddleware with those from the loaded middleware + Object.assign(securityMiddleware, middleware); + }) + .catch((error) => { + console.error("Failed to initialize security middleware:", error); + }); diff --git a/src/server/Worker.ts b/src/server/Worker.ts index 3b3932b04..18efc82b2 100644 --- a/src/server/Worker.ts +++ b/src/server/Worker.ts @@ -13,6 +13,7 @@ import { GameConfig, GameRecord, LogSeverity } from "../core/Schemas"; import { slog } from "./StructuredLog"; import { GameType } from "../core/game/Game"; import { archive } from "./Archive"; +import { LimiterType, securityMiddleware } from "./Security"; const config = getServerConfig(); @@ -76,32 +77,10 @@ export function startWorker() { duration: 240, // 4 minutes }); - // Async handler with rate limiting - const asyncHandler = - (fn: Function, limiter = null) => - async (req: Request, res: Response, next: NextFunction) => { - try { - if (limiter) { - if (!isLocalhost(req)) { - const clientIP = req.ip || req.socket.remoteAddress || "unknown"; - try { - await limiter.consume(clientIP); - } catch (error) { - console.warn(`Rate limited for IP ${clientIP}`); - return res.status(429).json({ error: "Too many requests" }); - } - } - } - await fn(req, res, next); - } catch (error) { - next(error); - } - }; - // Endpoint to create a private lobby app.post( "/create_game/:id", - asyncHandler(async (req, res) => { + securityMiddleware.httpHandler(async (req, res) => { const id = req.params.id; if (!id) { console.warn(`cannot create game, id not found`); @@ -132,13 +111,13 @@ export function startWorker() { `Worker ${workerId}: IP ${clientIP} creating game ${game.isPublic() ? "Public" : "Private"} with id ${id}`, ); res.json(game.gameInfo()); - }, updateRateLimiter), + }, LimiterType.Post), ); // Add other endpoints from your original server app.post( "/start_game/:id", - asyncHandler(async (req, res) => { + securityMiddleware.httpHandler(async (req, res) => { console.log(`starting private lobby with id ${req.params.id}`); const game = gm.game(req.params.id); if (!game) { @@ -153,12 +132,12 @@ export function startWorker() { } game.start(); res.status(200).json({ success: true }); - }, updateRateLimiter), + }, LimiterType.Post), ); app.put( "/game/:id", - asyncHandler(async (req, res) => { + securityMiddleware.httpHandler(async (req, res) => { // TODO: only update public game if from local host const lobbyID = req.params.id; if (req.body.gameType == GameType.Public) { @@ -184,34 +163,34 @@ export function startWorker() { disableNPCs: req.body.disableNPCs, }); res.status(200).json({ success: true }); - }), + }, LimiterType.Put), ); app.get( "/game/:id/exists", - asyncHandler(async (req, res) => { + securityMiddleware.httpHandler(async (req, res) => { const lobbyId = req.params.id; res.json({ exists: gm.game(lobbyId) != null, }); - }), + }, LimiterType.Get), ); app.get( "/game/:id", - asyncHandler(async (req, res) => { + securityMiddleware.httpHandler(async (req, res) => { const game = gm.game(req.params.id); if (game == null) { console.log(`lobby ${req.params.id} not found`); return res.status(404).json({ error: "Game not found" }); } res.json(game.gameInfo()); - }), + }, LimiterType.Get), ); app.post( "/archive_singleplayer_game", - asyncHandler(async (req, res) => { + securityMiddleware.httpHandler(async (req, res) => { const gameRecord: GameRecord = req.body; const clientIP = req.ip || req.socket.remoteAddress || "unknown"; @@ -225,71 +204,74 @@ export function startWorker() { res.json({ success: true, }); - }, updateRateLimiter), + }, LimiterType.Post), ); // WebSocket handling wss.on("connection", (ws: WebSocket, req) => { - ws.on("message", async (message: string) => { - const forwarded = req.headers["x-forwarded-for"]; - const ip = Array.isArray(forwarded) - ? forwarded[0] - : forwarded || req.socket.remoteAddress; - try { - await rateLimiter.consume(ip); - } catch (error) { - console.warn(`rate limit exceeded for ${ip}`); - return; - } - - try { - // Process WebSocket messages as in your original code - // Parse and handle client messages - const clientMsg = JSON.parse(message.toString()); - - if (clientMsg.type == "join") { - // Verify this worker should handle this game - const expectedWorkerId = config.workerIndex(clientMsg.gameID); - if (expectedWorkerId !== workerId) { - console.warn( - `Worker mismatch: Game ${clientMsg.gameID} should be on worker ${expectedWorkerId}, but this is worker ${workerId}`, - ); - return; - } - - // Create client and add to game - const client = new Client( - clientMsg.clientID, - clientMsg.persistentID, - ip, - clientMsg.username, - ws, - ); - - const wasFound = gm.addClient( - client, - clientMsg.gameID, - clientMsg.lastTurn, - ); - - if (!wasFound) { - console.log( - `game ${clientMsg.gameID} not found on worker ${workerId}`, - ); - // Handle game not found case - } + ws.on( + "message", + securityMiddleware.wsHandler(req, async (message: string) => { + const forwarded = req.headers["x-forwarded-for"]; + const ip = Array.isArray(forwarded) + ? forwarded[0] + : forwarded || req.socket.remoteAddress; + try { + await rateLimiter.consume(ip); + } catch (error) { + console.warn(`rate limit exceeded for ${ip}`); + return; } - // Handle other message types - } catch (error) { - console.warn( - `error handling websocket message for ${ip}: ${error}`.substring( - 0, - 250, - ), - ); - } - }); + try { + // Process WebSocket messages as in your original code + // Parse and handle client messages + const clientMsg = JSON.parse(message.toString()); + + if (clientMsg.type == "join") { + // Verify this worker should handle this game + const expectedWorkerId = config.workerIndex(clientMsg.gameID); + if (expectedWorkerId !== workerId) { + console.warn( + `Worker mismatch: Game ${clientMsg.gameID} should be on worker ${expectedWorkerId}, but this is worker ${workerId}`, + ); + return; + } + + // Create client and add to game + const client = new Client( + clientMsg.clientID, + clientMsg.persistentID, + ip, + clientMsg.username, + ws, + ); + + const wasFound = gm.addClient( + client, + clientMsg.gameID, + clientMsg.lastTurn, + ); + + if (!wasFound) { + console.log( + `game ${clientMsg.gameID} not found on worker ${workerId}`, + ); + // Handle game not found case + } + } + + // Handle other message types + } catch (error) { + console.warn( + `error handling websocket message for ${ip}: ${error}`.substring( + 0, + 250, + ), + ); + } + }), + ); ws.on("error", (error: Error) => { if ((error as any).code === "WS_ERR_UNEXPECTED_RSV_1") { diff --git a/src/server/security-middleware b/src/server/security-middleware index 743460db2..8fc4ab3ef 160000 --- a/src/server/security-middleware +++ b/src/server/security-middleware @@ -1 +1 @@ -Subproject commit 743460db25cf2dd151f59c73f25f663e224f0b88 +Subproject commit 8fc4ab3ef93ed92e58d53fcaef7824a95636cf3f