added security middleware

This commit is contained in:
Evan
2025-03-01 20:48:49 -08:00
parent 0b34e2b5e6
commit daa9820c32
5 changed files with 227 additions and 148 deletions
+2 -1
View File
@@ -10,4 +10,5 @@
.prettierignore
.gitignore
Dockerfile
*.conf
*.conf
.gitmodules
+47 -54
View File
@@ -19,6 +19,7 @@ import { GameType } from "../core/game/Game";
import { archive } from "./Archive";
import { Client } from "./Client";
import { slog } from "./StructuredLog";
import { securityMiddleware } from "./Security";
export enum GamePhase {
Lobby = "LOBBY",
@@ -27,11 +28,6 @@ export enum GamePhase {
}
export class GameServer {
private rateLimiter = new RateLimiterMemory({
points: 50,
duration: 1, // per 1 second
});
private outOfSyncClients = new Set<ClientID>();
private maxGameDuration = 3 * 60 * 60 * 1000; // 3 hours
@@ -125,58 +121,55 @@ export class GameServer {
this.allClients.set(client.clientID, client);
client.ws.on("message", async (message: string) => {
try {
await this.rateLimiter.consume(client.ip);
} catch (error) {
console.warn(`Rate limit exceeded for ${client.ip}`);
return;
}
try {
let clientMsg: ClientMessage = null;
client.ws.on(
"message",
securityMiddleware.wsHandler(client.ip, async (message: string) => {
try {
clientMsg = ClientMessageSchema.parse(JSON.parse(message));
let clientMsg: ClientMessage = null;
try {
clientMsg = ClientMessageSchema.parse(JSON.parse(message));
} catch (error) {
throw Error(`error parsing schema for ${client.ip}`);
}
if (this.allClients.has(clientMsg.clientID)) {
const client = this.allClients.get(clientMsg.clientID);
if (client.persistentID != clientMsg.persistentID) {
console.warn(
`Client ID ${clientMsg.clientID} sent incorrect id ${clientMsg.persistentID}, does not match persistent id ${client.persistentID}`,
);
return;
}
}
// Clear out persistent id to make sure it doesn't get sent to other clients.
clientMsg.persistentID = null;
if (clientMsg.type == "intent") {
if (clientMsg.gameID == this.id) {
this.addIntent(clientMsg.intent);
} else {
console.warn(
`${this.id}: client ${clientMsg.clientID} sent to wrong game`,
);
}
}
if (clientMsg.type == "ping") {
this.lastPingUpdate = Date.now();
client.lastPing = Date.now();
}
if (clientMsg.type == "hash") {
client.hashes.set(clientMsg.tick, clientMsg.hash);
}
if (clientMsg.type == "winner") {
this.winner = clientMsg.winner;
}
} catch (error) {
throw Error(`error parsing schema for ${client.ip}`);
console.log(
`error handline websocket request in game server: ${error}`,
);
}
if (this.allClients.has(clientMsg.clientID)) {
const client = this.allClients.get(clientMsg.clientID);
if (client.persistentID != clientMsg.persistentID) {
console.warn(
`Client ID ${clientMsg.clientID} sent incorrect id ${clientMsg.persistentID}, does not match persistent id ${client.persistentID}`,
);
return;
}
}
// Clear out persistent id to make sure it doesn't get sent to other clients.
clientMsg.persistentID = null;
if (clientMsg.type == "intent") {
if (clientMsg.gameID == this.id) {
this.addIntent(clientMsg.intent);
} else {
console.warn(
`${this.id}: client ${clientMsg.clientID} sent to wrong game`,
);
}
}
if (clientMsg.type == "ping") {
this.lastPingUpdate = Date.now();
client.lastPing = Date.now();
}
if (clientMsg.type == "hash") {
client.hashes.set(clientMsg.tick, clientMsg.hash);
}
if (clientMsg.type == "winner") {
this.winner = clientMsg.winner;
}
} catch (error) {
console.log(
`error handline websocket request in game server: ${error}`,
);
}
});
}),
);
client.ws.on("close", () => {
console.log(`${this.id}: client ${client.clientID} disconnected`);
this.activeClients = this.activeClients.filter(
+103
View File
@@ -0,0 +1,103 @@
// src/server/middleware/securityInterface.ts
import { Request, Response, NextFunction } from "express";
import http from "http";
import path from "path";
import { fileURLToPath } from "url";
export enum LimiterType {
Get = "get",
Post = "post",
Put = "put",
WebSocket = "websocket",
}
export interface SecurityMiddleware {
// The wrapper for request handlers with optional rate limiting
httpHandler: (
fn: (req: Request, res: Response, next: NextFunction) => Promise<any>,
limiterType: LimiterType,
) => (req: Request, res: Response, next: NextFunction) => Promise<void>;
// The wrapper for WebSocket message handlers with rate limiting
wsHandler: (
req: http.IncomingMessage | string,
fn: (message: string) => Promise<void>,
) => (message: string) => Promise<void>;
}
// Function to get the appropriate security middleware implementation
async function getSecurityMiddleware(): Promise<SecurityMiddleware> {
try {
// Get the current file's directory
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
try {
// Use dynamic import for ES modules - without file extension
// ts-node will resolve this correctly
const module = await import(
"./security-middleware/RealSecurityMiddleware"
);
if (!module.RealSecurityMiddleware) {
throw new Error("RealSecurityMiddleware class not found in module");
}
console.log("Successfully loaded real security middleware");
return new module.RealSecurityMiddleware();
} catch (error) {
console.log("Failed to load real security middleware:", error);
return new NoOpSecurityMiddleware();
}
} catch (e) {
// Fall back to no-op if real implementation isn't available
console.log("using no-op security middleware", e);
return new NoOpSecurityMiddleware();
}
}
export class NoOpSecurityMiddleware implements SecurityMiddleware {
// Simple pass-through with no rate limiting
httpHandler(
fn: (req: Request, res: Response, next: NextFunction) => Promise<any>,
limiterType: LimiterType,
) {
return async (req: Request, res: Response, next: NextFunction) => {
try {
await fn(req, res, next);
} catch (error) {
next(error);
}
};
}
// Corrected implementation for WebSocket handler wrapper
wsHandler(
req: http.IncomingMessage | string,
fn: (message: string) => Promise<void>,
) {
return async (message: string) => {
try {
await fn(message);
} catch (error) {
console.error("WebSocket handler error:", error);
}
};
}
}
// Initialize the security middleware with a default implementation
// We'll use the NoOpSecurityMiddleware initially and then replace it
// with the real implementation once it's loaded
export const securityMiddleware: SecurityMiddleware =
new NoOpSecurityMiddleware();
// Immediately try to load the real middleware
getSecurityMiddleware()
.then((middleware) => {
// Replace the methods of securityMiddleware with those from the loaded middleware
Object.assign(securityMiddleware, middleware);
})
.catch((error) => {
console.error("Failed to initialize security middleware:", error);
});
+74 -92
View File
@@ -13,6 +13,7 @@ import { GameConfig, GameRecord, LogSeverity } from "../core/Schemas";
import { slog } from "./StructuredLog";
import { GameType } from "../core/game/Game";
import { archive } from "./Archive";
import { LimiterType, securityMiddleware } from "./Security";
const config = getServerConfig();
@@ -76,32 +77,10 @@ export function startWorker() {
duration: 240, // 4 minutes
});
// Async handler with rate limiting
const asyncHandler =
(fn: Function, limiter = null) =>
async (req: Request, res: Response, next: NextFunction) => {
try {
if (limiter) {
if (!isLocalhost(req)) {
const clientIP = req.ip || req.socket.remoteAddress || "unknown";
try {
await limiter.consume(clientIP);
} catch (error) {
console.warn(`Rate limited for IP ${clientIP}`);
return res.status(429).json({ error: "Too many requests" });
}
}
}
await fn(req, res, next);
} catch (error) {
next(error);
}
};
// Endpoint to create a private lobby
app.post(
"/create_game/:id",
asyncHandler(async (req, res) => {
securityMiddleware.httpHandler(async (req, res) => {
const id = req.params.id;
if (!id) {
console.warn(`cannot create game, id not found`);
@@ -132,13 +111,13 @@ export function startWorker() {
`Worker ${workerId}: IP ${clientIP} creating game ${game.isPublic() ? "Public" : "Private"} with id ${id}`,
);
res.json(game.gameInfo());
}, updateRateLimiter),
}, LimiterType.Post),
);
// Add other endpoints from your original server
app.post(
"/start_game/:id",
asyncHandler(async (req, res) => {
securityMiddleware.httpHandler(async (req, res) => {
console.log(`starting private lobby with id ${req.params.id}`);
const game = gm.game(req.params.id);
if (!game) {
@@ -153,12 +132,12 @@ export function startWorker() {
}
game.start();
res.status(200).json({ success: true });
}, updateRateLimiter),
}, LimiterType.Post),
);
app.put(
"/game/:id",
asyncHandler(async (req, res) => {
securityMiddleware.httpHandler(async (req, res) => {
// TODO: only update public game if from local host
const lobbyID = req.params.id;
if (req.body.gameType == GameType.Public) {
@@ -184,34 +163,34 @@ export function startWorker() {
disableNPCs: req.body.disableNPCs,
});
res.status(200).json({ success: true });
}),
}, LimiterType.Put),
);
app.get(
"/game/:id/exists",
asyncHandler(async (req, res) => {
securityMiddleware.httpHandler(async (req, res) => {
const lobbyId = req.params.id;
res.json({
exists: gm.game(lobbyId) != null,
});
}),
}, LimiterType.Get),
);
app.get(
"/game/:id",
asyncHandler(async (req, res) => {
securityMiddleware.httpHandler(async (req, res) => {
const game = gm.game(req.params.id);
if (game == null) {
console.log(`lobby ${req.params.id} not found`);
return res.status(404).json({ error: "Game not found" });
}
res.json(game.gameInfo());
}),
}, LimiterType.Get),
);
app.post(
"/archive_singleplayer_game",
asyncHandler(async (req, res) => {
securityMiddleware.httpHandler(async (req, res) => {
const gameRecord: GameRecord = req.body;
const clientIP = req.ip || req.socket.remoteAddress || "unknown";
@@ -225,71 +204,74 @@ export function startWorker() {
res.json({
success: true,
});
}, updateRateLimiter),
}, LimiterType.Post),
);
// WebSocket handling
wss.on("connection", (ws: WebSocket, req) => {
ws.on("message", async (message: string) => {
const forwarded = req.headers["x-forwarded-for"];
const ip = Array.isArray(forwarded)
? forwarded[0]
: forwarded || req.socket.remoteAddress;
try {
await rateLimiter.consume(ip);
} catch (error) {
console.warn(`rate limit exceeded for ${ip}`);
return;
}
try {
// Process WebSocket messages as in your original code
// Parse and handle client messages
const clientMsg = JSON.parse(message.toString());
if (clientMsg.type == "join") {
// Verify this worker should handle this game
const expectedWorkerId = config.workerIndex(clientMsg.gameID);
if (expectedWorkerId !== workerId) {
console.warn(
`Worker mismatch: Game ${clientMsg.gameID} should be on worker ${expectedWorkerId}, but this is worker ${workerId}`,
);
return;
}
// Create client and add to game
const client = new Client(
clientMsg.clientID,
clientMsg.persistentID,
ip,
clientMsg.username,
ws,
);
const wasFound = gm.addClient(
client,
clientMsg.gameID,
clientMsg.lastTurn,
);
if (!wasFound) {
console.log(
`game ${clientMsg.gameID} not found on worker ${workerId}`,
);
// Handle game not found case
}
ws.on(
"message",
securityMiddleware.wsHandler(req, async (message: string) => {
const forwarded = req.headers["x-forwarded-for"];
const ip = Array.isArray(forwarded)
? forwarded[0]
: forwarded || req.socket.remoteAddress;
try {
await rateLimiter.consume(ip);
} catch (error) {
console.warn(`rate limit exceeded for ${ip}`);
return;
}
// Handle other message types
} catch (error) {
console.warn(
`error handling websocket message for ${ip}: ${error}`.substring(
0,
250,
),
);
}
});
try {
// Process WebSocket messages as in your original code
// Parse and handle client messages
const clientMsg = JSON.parse(message.toString());
if (clientMsg.type == "join") {
// Verify this worker should handle this game
const expectedWorkerId = config.workerIndex(clientMsg.gameID);
if (expectedWorkerId !== workerId) {
console.warn(
`Worker mismatch: Game ${clientMsg.gameID} should be on worker ${expectedWorkerId}, but this is worker ${workerId}`,
);
return;
}
// Create client and add to game
const client = new Client(
clientMsg.clientID,
clientMsg.persistentID,
ip,
clientMsg.username,
ws,
);
const wasFound = gm.addClient(
client,
clientMsg.gameID,
clientMsg.lastTurn,
);
if (!wasFound) {
console.log(
`game ${clientMsg.gameID} not found on worker ${workerId}`,
);
// Handle game not found case
}
}
// Handle other message types
} catch (error) {
console.warn(
`error handling websocket message for ${ip}: ${error}`.substring(
0,
250,
),
);
}
}),
);
ws.on("error", (error: Error) => {
if ((error as any).code === "WS_ERR_UNEXPECTED_RSV_1") {