diff --git a/src/core/execution/TrainStationExecution.ts b/src/core/execution/TrainStationExecution.ts index af8a5bd3f..5f029a0cc 100644 --- a/src/core/execution/TrainStationExecution.ts +++ b/src/core/execution/TrainStationExecution.ts @@ -45,7 +45,9 @@ export class TrainStationExecution implements Execution { this.active = false; return; } - this.spawnTrain(this.station, ticks); + if (this.spawnTrains) { + this.spawnTrain(this.station, ticks); + } } private shouldSpawnTrain(): boolean { @@ -69,8 +71,8 @@ export class TrainStationExecution implements Execution { if (cluster === null) { return; } - const availableForTrade = cluster.availableForTrade(this.unit.owner()); - if (availableForTrade.size === 0) { + const owner = this.unit.owner(); + if (!cluster.hasAnyTradeDestination(owner)) { return; } if (!this.shouldSpawnTrain()) { @@ -79,20 +81,20 @@ export class TrainStationExecution implements Execution { // Pick a destination randomly. // Could be improved to pick a lucrative trip - const destination: TrainStation = - this.random.randFromSet(availableForTrade); - if (destination !== station) { - this.mg.addExecution( - new TrainExecution( - this.mg.railNetwork(), - this.unit.owner(), - station, - destination, - this.numCars, - ), - ); - this.lastSpawnTick = currentTick; - } + const destination = cluster.randomTradeDestination(owner, this.random); + if (destination === null) return; + if (destination === station) return; + + this.mg.addExecution( + new TrainExecution( + this.mg.railNetwork(), + owner, + station, + destination, + this.numCars, + ), + ); + this.lastSpawnTick = currentTick; } activeDuringSpawnPhase(): boolean { diff --git a/src/core/game/TrainStation.ts b/src/core/game/TrainStation.ts index 9917737d9..6e9b57f0e 100644 --- a/src/core/game/TrainStation.ts +++ b/src/core/game/TrainStation.ts @@ -155,6 +155,12 @@ export class TrainStation { */ export class Cluster { public stations: Set = new Set(); + private tradeStations: Set = new Set(); + + private isTradeStation(station: TrainStation): boolean { + const type = station.unit.type(); + return type === UnitType.City || type === UnitType.Port; + } has(station: TrainStation) { return this.stations.has(station); @@ -162,11 +168,15 @@ export class Cluster { addStation(station: TrainStation) { this.stations.add(station); + if (this.isTradeStation(station)) { + this.tradeStations.add(station); + } station.setCluster(this); } removeStation(station: TrainStation) { this.stations.delete(station); + this.tradeStations.delete(station); } addStations(stations: Set) { @@ -181,14 +191,39 @@ export class Cluster { } } + hasAnyTradeDestination(player: Player): boolean { + for (const station of this.tradeStations) { + if (station.tradeAvailable(player)) { + return true; + } + } + return false; + } + + randomTradeDestination( + player: Player, + random: PseudoRandom, + ): TrainStation | null { + let selected: TrainStation | null = null; + let eligibleSeen = 0; + + for (const station of this.tradeStations) { + if (!station.tradeAvailable(player)) continue; + eligibleSeen++; + + // Reservoir sampling: keep each eligible station with probability 1/eligibleSeen. + if (random.nextInt(0, eligibleSeen) === 0) { + selected = station; + } + } + + return selected; + } + availableForTrade(player: Player): Set { const tradingStations = new Set(); - for (const station of this.stations) { - if ( - (station.unit.type() === UnitType.City || - station.unit.type() === UnitType.Port) && - station.tradeAvailable(player) - ) { + for (const station of this.tradeStations) { + if (station.tradeAvailable(player)) { tradingStations.add(station); } } @@ -201,6 +236,7 @@ export class Cluster { clear() { this.stations.clear(); + this.tradeStations.clear(); } } diff --git a/tests/core/game/Cluster.test.ts b/tests/core/game/Cluster.test.ts index 138e3968e..f64d73d09 100644 --- a/tests/core/game/Cluster.test.ts +++ b/tests/core/game/Cluster.test.ts @@ -1,9 +1,13 @@ import { vi, type Mocked } from "vitest"; +import { UnitType } from "../../../src/core/game/Game"; import { Cluster, TrainStation } from "../../../src/core/game/TrainStation"; const createMockStation = (id: string): Mocked => { return { id, + unit: { + type: vi.fn(() => UnitType.City), + } as any, setCluster: vi.fn(), getCluster: vi.fn(() => null), } as any; diff --git a/tests/core/game/RailNetwork.test.ts b/tests/core/game/RailNetwork.test.ts index fc39db81d..70be4febd 100644 --- a/tests/core/game/RailNetwork.test.ts +++ b/tests/core/game/RailNetwork.test.ts @@ -1,4 +1,4 @@ -import { Unit } from "../../../src/core/game/Game"; +import { Unit, UnitType } from "../../../src/core/game/Game"; import { RailNetworkImpl, StationManagerImpl, @@ -14,6 +14,7 @@ const createMockStation = (unitId: number): any => { unit: { id: unitId, setTrainStation: vi.fn(), + type: vi.fn(() => UnitType.City), }, tile: vi.fn(), neighbors: vi.fn(() => []),