Files
Momento/mcp-server/index-sse.js
Antigravity 4d96605144
All checks were successful
CI / Lint, Unit Tests & Build (push) Successful in 5m42s
CI / Deploy production (on server) (push) Successful in 33s
fix(security): Phase 1 P0 hardening from cross-project audit
Close open uploads, image-proxy SSRF, fail-open AI quotas in production,
auth gaps on app routes, and MCP tenant isolation issues.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-20 16:53:19 +00:00

574 lines
21 KiB
JavaScript

#!/usr/bin/env node
/**
* Memento MCP Server - Streamable HTTP Transport (Enhanced)
*
* Features:
* - Prisma connection pooling
* - Compact JSON output
* - Bounded session cache
* - Proper keep-alive & timeouts
* - O(1) API key validation
* - Structured error handling
* - Observability metrics
* - Rate limiting
* - Input validation
* - Audit logging
*
* Environment:
* PORT Server port (default: 3001)
* DATABASE_URL Prisma database URL
* USER_ID Optional user ID filter
* APP_BASE_URL Next.js app URL (default: http://localhost:3000)
* MCP_REQUIRE_AUTH Set 'true' to require authentication
* MCP_API_KEY Static fallback API key
* MCP_LOG_LEVEL debug, info, warn, error (default: info)
* MCP_REQUEST_TIMEOUT Timeout in ms (default: 30000)
*/
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { PrismaClient } from '@prisma/client';
import { createHash, randomBytes, timingSafeEqual } from 'crypto';
import express from 'express';
import cors from 'cors';
import { registerTools } from './tools.js';
import { validateApiKey } from './auth.js';
import { requestContext } from './request-context.js';
import config, { validateConfig, printConfig } from './config.js';
import {
mcpError,
mcpErrorContent,
McpErrors,
getErrorCategory,
withErrorHandling,
logError,
} from './errors.js';
import {
recordRequest,
recordError,
recordAuth,
recordDbQuery,
recordSession,
getPrometheusMetrics,
getMetricsSummary,
updateCacheSize,
} from './metrics.js';
import { combinedRateLimitMiddleware, getRateLimitStats } from './rate-limit.js';
import { validateAndSanitize, checkXSS } from './validation.js';
// ═══════════════════════════════════════════════════════════════
// Configuration Validation
// ═══════════════════════════════════════════════════════════════
// Fail fast when validateConfig() reports any critical problem (every reported
// entry is printed); otherwise print all reported entries as warnings and continue.
const configErrors = validateConfig();
const formatConfigIssue = ({ key, message }) => ` ${key}: ${message}`;
if (configErrors.some(({ critical }) => critical)) {
  console.error('❌ CRITICAL CONFIGURATION ERRORS:');
  for (const issue of configErrors) {
    console.error(formatConfigIssue(issue));
  }
  process.exit(1);
}
if (configErrors.length > 0) {
  console.warn('⚠️ Configuration warnings:');
  for (const issue of configErrors) {
    console.warn(formatConfigIssue(issue));
  }
}
// ═══════════════════════════════════════════════════════════════
// Logging
// ═══════════════════════════════════════════════════════════════
const logLevels = { debug: 0, info: 1, warn: 2, error: 3, silent: 4 };
// Unknown or missing config.logLevel falls back to 'info' (1).
const currentLogLevel = logLevels[config.logLevel] ?? 1;
/**
 * Emit a timestamped log line via console.error when `level` is at or above
 * the configured threshold. Unknown level names are dropped.
 * @param {string} level - One of the keys of `logLevels`.
 * @param {...*} args - Values forwarded to console.error after the prefix.
 */
function log(level, ...args) {
  const severity = logLevels[level];
  if (!(severity >= currentLogLevel)) {
    return;
  }
  console.error(`[${new Date().toISOString()}] [${level.toUpperCase()}]`, ...args);
}
// ═══════════════════════════════════════════════════════════════
// Database Setup
// ═══════════════════════════════════════════════════════════════
const databaseUrl = config.databaseUrl;
if (!databaseUrl) {
  console.error('ERROR: DATABASE_URL is required');
  process.exit(1);
}
const isPostgres = databaseUrl.startsWith('postgresql://') || databaseUrl.startsWith('postgres://');
/**
 * URL handed to Prisma. For PostgreSQL the pool settings from config
 * (connection_limit, pool_timeout) are appended as query parameters, using '&'
 * when DATABASE_URL already has a query string; other URLs are used unchanged.
 */
const prismaDatasourceUrl = isPostgres
  ? `${databaseUrl}${databaseUrl.includes('?') ? '&' : '?'}connection_limit=${config.connectionLimit}&pool_timeout=${config.poolTimeout}`
  : databaseUrl;
// A single `datasources` entry built from the effective URL above, instead of a
// default entry that a conditional spread later overrides.
const prisma = new PrismaClient({
  datasources: { db: { url: prismaDatasourceUrl } },
  log: config.logLevel === 'debug' ? ['query', 'info', 'warn', 'error'] : ['warn', 'error'],
});
// Wrap Prisma for metrics
// prisma.$queryRaw is replaced by a wrapper that forwards all arguments (including
// tagged-template calls) to the original method and reports success/failure plus
// elapsed milliseconds through recordDbQuery. Other Prisma query methods are not wrapped.
// NOTE(review): being an async function, the wrapper returns a native Promise that runs
// the query immediately instead of Prisma's lazy PrismaPromise; confirm no caller passes
// $queryRaw results into prisma.$transaction([...]).
const originalQuery = prisma.$queryRaw.bind(prisma);
prisma.$queryRaw = async function timedQueryRaw(...args) {
  const startedAt = Date.now();
  let rows;
  try {
    rows = await originalQuery(...args);
    recordDbQuery(true, Date.now() - startedAt);
  } catch (error) {
    recordDbQuery(false, Date.now() - startedAt);
    throw error;
  }
  return rows;
};
// NOTE(review): appBaseUrl is not referenced by the code visible in this file.
const appBaseUrl = config.appBaseUrl;
// ═══════════════════════════════════════════════════════════════
// Bounded Session Cache
// ═══════════════════════════════════════════════════════════════
/**
 * Cached authenticated user sessions keyed by a credential-derived string.
 * Each value carries `_lastSeen` (ms since epoch) used for TTL expiry and
 * least-recently-seen eviction.
 */
const sessions = new Map();
/**
 * Remove sessions idle for longer than config.sessionTtl ms, record how many
 * expired, and publish the current cache size.
 */
function cleanupSessions() {
  const now = Date.now();
  let cleaned = 0;
  for (const [key, s] of sessions) {
    if (now - s._lastSeen > config.sessionTtl) {
      sessions.delete(key);
      cleaned++;
    }
  }
  if (cleaned > 0) {
    log('debug', `Cleaned ${cleaned} expired sessions`);
    recordSession('expire', cleaned);
  }
  updateCacheSize(sessions.size);
}
/**
 * Make room for a new session when the cache is at capacity by evicting the
 * least-recently-seen entries: about a quarter of config.maxSessions, and
 * always at least enough to drop below the limit.
 */
function pruneIfFull() {
  const { maxSessions } = config;
  if (sessions.size < maxSessions) return;
  // Math.floor(maxSessions / 4) alone is 0 when maxSessions < 4, which would evict
  // nothing and let the cache grow without bound; evict the overflow plus at least one.
  const evictCount = sessions.size - maxSessions + Math.max(1, Math.floor(maxSessions / 4));
  const oldestFirst = [...sessions.entries()].sort((a, b) => a[1]._lastSeen - b[1]._lastSeen);
  for (const [key] of oldestFirst.slice(0, evictCount)) {
    sessions.delete(key);
  }
  updateCacheSize(sessions.size);
}
// Run cleanupSessions every config.sessionCleanupInterval ms. The timer handle is not
// stored, so it is never cleared or unref'd for the lifetime of the process.
setInterval(cleanupSessions, config.sessionCleanupInterval);
// ═══════════════════════════════════════════════════════════════
// Express App Setup
// ═══════════════════════════════════════════════════════════════
const app = express();
// CORS: when an explicit origin list is configured (non-empty and without '*'),
// allow only those origins with credentials; otherwise use the cors() defaults.
const restrictCorsOrigins =
  config.allowedOrigins.length > 0 && !config.allowedOrigins.includes('*');
app.use(
  restrictCorsOrigins
    ? cors({ origin: config.allowedOrigins, credentials: true })
    : cors()
);
// JSON bodies are parsed up to config.maxRequestSize.
app.use(express.json({ limit: config.maxRequestSize }));
// ═══════════════════════════════════════════════════════════════
// Request Logging Middleware
// ═══════════════════════════════════════════════════════════════
// When the response finishes, log method/path/status/latency at debug level, tagged
// with the first 8 chars of the user session id (or 'anon'), and record HTTP metrics.
app.use(function requestLogger(req, res, next) {
  const startedAt = Date.now();
  res.on('finish', () => {
    const elapsedMs = Date.now() - startedAt;
    const sessionTag = req.userSession?.id?.substring(0, 8) || 'anon';
    log('debug', `[${sessionTag}] ${req.method} ${req.path} ${res.statusCode} ${elapsedMs}ms`);
    recordRequest('http', res.statusCode, req.method, elapsedMs);
  });
  next();
});
// ═══════════════════════════════════════════════════════════════
// Timeout Middleware
// ═══════════════════════════════════════════════════════════════
// Arm config.requestTimeout on the request and response. On timeout, if nothing has been
// sent yet, record a TIMEOUT error and reply 408; if headers were already sent (e.g. a
// streaming response), the timeout is ignored.
// NOTE(review): the downstream handler is not cancelled after the 408 and may still try
// to respond — confirm that is handled (e.g. by withErrorHandling).
app.use(function applyRequestTimeout(req, res, next) {
  req.setTimeout(config.requestTimeout);
  res.setTimeout(config.requestTimeout, () => {
    if (res.headersSent) {
      return;
    }
    const timeoutCode = McpErrors.TIMEOUT.code;
    recordError(getErrorCategory(timeoutCode), timeoutCode);
    res.status(408).json(mcpError(timeoutCode));
  });
  next();
});
// ═══════════════════════════════════════════════════════════════
// Security Middleware (XSS Check)
// ═══════════════════════════════════════════════════════════════
// Reject with 400 INVALID_PARAMS any request whose parsed body checkXSS() flags.
// Registered before rate limiting and auth.
app.use(function rejectSuspectedXss(req, res, next) {
  const flagged = req.body && checkXSS(req.body);
  if (!flagged) {
    return next();
  }
  recordError('xss', 'xss_detected', { path: req.path });
  return res.status(400).json(
    mcpError(McpErrors.INVALID_PARAMS.code, {
      detail: 'Request contains potentially malicious content',
    })
  );
});
// ═══════════════════════════════════════════════════════════════
// Rate Limiting Middleware
// ═══════════════════════════════════════════════════════════════
// Applied to every route registered after this point, including health and metrics.
app.use(combinedRateLimitMiddleware);
// ═══════════════════════════════════════════════════════════════
// Health Endpoint (before auth - for Docker healthcheck)
// ═══════════════════════════════════════════════════════════════
/**
 * GET <healthPath>: probes the database with `SELECT 1`. Responds 200 with uptime,
 * metrics summary, rate-limit stats and session counts, or 503 on failure.
 * Registered before the auth middleware, so it is reachable without credentials.
 */
app.get(config.healthPath, async (req, res) => {
  try {
    // Check database connection
    await prisma.$queryRaw`SELECT 1`;
    res.json({
      ok: true,
      uptime: process.uptime(),
      timestamp: new Date().toISOString(),
      metrics: getMetricsSummary(),
      rateLimit: getRateLimitStats(),
      sessions: {
        active: sessions.size,
        max: config.maxSessions,
      },
    });
  } catch (error) {
    // The unauthenticated caller only gets a generic message; keep the real cause
    // in the server log so failing healthchecks can be diagnosed.
    log('warn', 'Health check failed:', error?.message ?? error);
    res.status(503).json({
      ok: false,
      error: 'Database connection failed',
      uptime: process.uptime(),
      timestamp: new Date().toISOString(),
    });
  }
});
// ═══════════════════════════════════════════════════════════════
// Metrics Endpoint
// ═══════════════════════════════════════════════════════════════
// When enabled, serve Prometheus-format metrics as text/plain at config.metricsPath.
// Registered before the auth middleware, so no credentials are required.
if (config.enableMetrics) {
  app.get(config.metricsPath, function servePrometheusMetrics(req, res) {
    res.set('Content-Type', 'text/plain').send(getPrometheusMetrics());
  });
}
// ═══════════════════════════════════════════════════════════════
// Auth Middleware
// ═══════════════════════════════════════════════════════════════
/**
 * Constant-time equality check for secret strings.
 * Both values are SHA-256 hashed first so timingSafeEqual always gets equal-length
 * buffers and the comparison time does not reveal the expected key's length or the
 * position of the first differing character.
 * @param {string} provided - Value supplied by the client.
 * @param {string} expected - Configured secret.
 * @returns {boolean} True when the strings are identical.
 */
function secretsMatch(provided, expected) {
  const providedDigest = createHash('sha256').update(String(provided)).digest();
  const expectedDigest = createHash('sha256').update(String(expected)).digest();
  return timingSafeEqual(providedDigest, expectedDigest);
}
/**
 * Attach `req.userSession` or reject with 401.
 * - Auth disabled (config.requireAuth false): a fixed dev-user session.
 * - Otherwise the `x-api-key` header must match a key known to validateApiKey()
 *   or, as a fallback, config.staticApiKey.
 */
app.use(
  withErrorHandling(async (req, res, next) => {
    if (!config.requireAuth) {
      req.userSession = {
        id: 'dev-user',
        name: 'Development User',
        isAuth: false,
        userId: config.userId || null,
      };
      recordAuth(true, 'dev-mode');
      return next();
    }
    const apiKey = req.headers['x-api-key'];
    if (!apiKey || typeof apiKey !== 'string') {
      recordAuth(false, 'missing-credentials');
      return res
        .status(401)
        .json(
          mcpError(McpErrors.AUTH_FAILED.code, {
            detail: 'Provide x-api-key header',
          })
        );
    }
    const keyUser = await validateApiKey(prisma, apiKey);
    if (keyUser) {
      req.userSession = getOrCreateSession(`key:${keyUser.apiKeyId}`, {
        name: `${keyUser.userName} (${keyUser.apiKeyName})`,
        userId: keyUser.userId,
        userName: keyUser.userName,
        apiKeyId: keyUser.apiKeyId,
        authMethod: 'api-key',
      });
      recordAuth(true, 'api-key');
      return next();
    }
    // Constant-time comparison: `===` would leak how many leading characters match.
    if (config.staticApiKey && secretsMatch(apiKey, config.staticApiKey)) {
      req.userSession = getOrCreateSession(`static:${apiKey.substring(0, 8)}`, {
        name: 'Static API Key User',
        userId: config.userId || null,
        authMethod: 'static-key',
      });
      recordAuth(true, 'static-key');
      return next();
    }
    recordAuth(false, 'invalid-api-key');
    return res.status(401).json(mcpError(McpErrors.AUTH_FAILED.code, { detail: 'Invalid API key' }));
  })
);
/**
 * Look up the cached session for `key`. If present, bump its last-seen time and request
 * count and return it. Otherwise evict old sessions if the cache is full, then build,
 * cache and return a new authenticated session seeded with `base`.
 * @param {string} key - Cache key derived from the credential.
 * @param {object} base - Fields copied into a newly created session (isAuth is forced true).
 * @returns {object} The (possibly shared) cached session object.
 */
function getOrCreateSession(key, base) {
  let session = sessions.get(key);
  if (!session) {
    pruneIfFull();
    session = {
      id: randomBytes(16).toString('hex'),
      ...base,
      connectedAt: new Date().toISOString(),
      requestCount: 1,
      isAuth: true,
      _lastSeen: Date.now(),
    };
    sessions.set(key, session);
    recordSession('create');
    return session;
  }
  session._lastSeen = Date.now();
  session.requestCount = (session.requestCount || 0) + 1;
  return session;
}
// ═══════════════════════════════════════════════════════════════
// MCP Server Setup
// ═══════════════════════════════════════════════════════════════
// Module-level MCP Server advertising the `tools` capability. registerTools() from
// ./tools.js receives this server and the shared Prisma client (presumably it installs
// the tool handlers — confirm in tools.js).
// NOTE(review): the MCP SDK recommends a separate Server instance per transport/session;
// connecting one instance to several transports lets concurrent clients' request ids
// collide — confirm how this instance is used.
const server = new Server(
{ name: 'memento-mcp-server', version: '3.2.0' },
{ capabilities: { tools: {} } },
);
registerTools(server, prisma);
// ═══════════════════════════════════════════════════════════════
// Routes
// ═══════════════════════════════════════════════════════════════
// Service descriptor: name, version, endpoint paths, whether auth is enabled, uptime.
app.get('/', (req, res) => {
  res.json({
    name: 'Memento MCP Server',
    version: '3.2.0',
    status: 'running',
    endpoints: {
      mcp: '/mcp',
      health: config.healthPath,
      metrics: config.enableMetrics ? config.metricsPath : undefined,
      sessions: '/sessions',
    },
    auth: { enabled: config.requireAuth },
    tools: 22,
    uptime: process.uptime(),
  });
});
/**
 * GET /sessions: summary of cached user sessions.
 * A caller authenticated with a credential (userSession.isAuth) only sees sessions
 * whose userId equals its own, so one tenant cannot enumerate other tenants' user
 * names and API-key names. With auth disabled (isAuth false) every cached session
 * is listed.
 */
app.get('/sessions', (req, res) => {
  const caller = req.userSession;
  const list = [...sessions.values()]
    .filter((s) => !caller?.isAuth || s.userId === caller.userId)
    .map((s) => ({
      id: s.id,
      name: s.name,
      connectedAt: s.connectedAt,
      requestCount: s.requestCount || 0,
      authMethod: s.authMethod,
    }));
  res.json({ activeUsers: list.length, sessions: list, uptime: process.uptime() });
});
// ═══════════════════════════════════════════════════════════════
// MCP Endpoint with Input Validation
// ═══════════════════════════════════════════════════════════════
/** Open Streamable HTTP transports keyed by MCP session id. */
const transports = {};
/**
 * userId (or null) of the credential that initialized each MCP session id.
 * Follow-up requests for a session must come from the same userId.
 */
const transportOwners = new Map();
/**
 * Create an MCP Server with every tool registered, to be connected to exactly one
 * transport. The MCP SDK recommends a separate Server per transport: a Server tracks
 * its active transport and in-flight request ids, so sharing one instance across
 * sessions lets concurrent clients' request ids collide.
 * @returns {Server}
 */
function createSessionServer() {
  const sessionServer = new Server(
    { name: 'memento-mcp-server', version: '3.2.0' },
    { capabilities: { tools: {} } },
  );
  registerTools(sessionServer, prisma);
  return sessionServer;
}
/**
 * All MCP traffic. Reuses the transport for a known `mcp-session-id` (only for the
 * user that created it); otherwise creates a new transport and Server pair. It then
 * runs the existing input validation and hands the request to the transport inside a
 * request context carrying the caller's userId.
 */
app.all(
  '/mcp',
  withErrorHandling(async (req, res) => {
    const sessionId = req.headers['mcp-session-id'];
    const ownerId = req.userSession?.userId || null;
    let transport;
    // Object.hasOwn ignores inherited keys such as "constructor" or "__proto__".
    if (sessionId && Object.hasOwn(transports, sessionId)) {
      if (transportOwners.get(sessionId) !== ownerId) {
        recordAuth(false, 'mcp-session-owner-mismatch');
        return res.status(403).json(
          mcpError(McpErrors.AUTH_FAILED.code, {
            detail: 'MCP session belongs to a different user',
          })
        );
      }
      transport = transports[sessionId];
    } else {
      transport = new StreamableHTTPServerTransport({
        sessionIdGenerator: () => randomBytes(16).toString('hex'),
        onsessioninitialized: (id) => {
          log('debug', `Session init: ${id}`);
          transports[id] = transport;
          transportOwners.set(id, ownerId);
        },
      });
      transport.onclose = () => {
        const sid = transport.sessionId;
        if (sid) {
          log('debug', `Session close: ${sid}`);
          delete transports[sid];
          transportOwners.delete(sid);
        }
      };
      await createSessionServer().connect(transport);
    }
    // Validate tool input if present
    // NOTE(review): req.body.method is the JSON-RPC method (e.g. 'tools/call'), not a
    // tool name; tool arguments live in req.body.params.arguments under
    // req.body.params.name, and batched (array) bodies skip this check entirely.
    // Confirm which names validateAndSanitize() expects.
    if (req.body?.method) {
      const toolName = req.body.method;
      if (req.body?.params) {
        const validation = validateAndSanitize(toolName, req.body.params);
        if (!validation.success) {
          log('warn', `Validation failed for ${toolName}:`, validation.errors);
          return res
            .status(400)
            .json(
              mcpError(McpErrors.INVALID_PARAMS.code, {
                detail: 'Input validation failed',
                field: validation.errors[0]?.field,
                context: { toolName, errors: validation.errors },
              })
            );
        }
        // Update request with sanitized data
        req.body.params = validation.data;
      }
    }
    const ctx = { userId: ownerId };
    await requestContext.run(ctx, async () => {
      await transport.handleRequest(req, res, req.body);
    });
  })
);
// Legacy /sse → /mcp redirect
// Clients of the old SSE endpoint get a 307, which keeps the HTTP method and body
// when the client follows it to /mcp.
app.all('/sse', function redirectLegacySse(req, res) {
  res.redirect(307, '/mcp');
});
// ═══════════════════════════════════════════════════════════════
// Debug Routes (only in development)
// ═══════════════════════════════════════════════════════════════
// Registered only when config.nodeEnv === 'development', after the auth middleware.
if (config.nodeEnv === 'development') {
  // This file is an ES module, where `require` is not defined; load
  // getPublicConfig lazily with a dynamic import and forward failures to Express.
  app.get('/debug/config', async (req, res, next) => {
    try {
      const { getPublicConfig } = await import('./config.js');
      res.json({ config: getPublicConfig() });
    } catch (error) {
      next(error);
    }
  });
  // Full dump of the session cache, including cache keys and last-seen times.
  app.get('/debug/sessions', (req, res) => {
    const sessionList = [...sessions.entries()].map(([key, s]) => ({
      key,
      id: s.id,
      name: s.name,
      requestCount: s.requestCount || 0,
      _lastSeen: s._lastSeen,
    }));
    res.json({ sessions: sessionList, total: sessions.size });
  });
  // Manual evictions keep the cache-size gauge in sync, like cleanupSessions does.
  app.delete('/debug/sessions/:key', (req, res) => {
    sessions.delete(req.params.key);
    updateCacheSize(sessions.size);
    res.json({ ok: true });
  });
  app.post('/debug/sessions/clear', (req, res) => {
    sessions.clear();
    updateCacheSize(sessions.size);
    res.json({ ok: true });
  });
}
// ═══════════════════════════════════════════════════════════════
// Start Server
// ═══════════════════════════════════════════════════════════════
/**
 * Verify the database is reachable (exit 1 if not), print the configuration, then
 * listen on 0.0.0.0:config.port and print a startup banner.
 */
async function main() {
  try {
    await prisma.$queryRaw`SELECT 1`;
  } catch (error) {
    console.error('FATAL: Database connection failed:', error.message);
    process.exit(1);
  }
  // Print configuration
  printConfig();
  // Build the feature summary from a list so a disabled feature leaves no stray
  // separator (e.g. ", Audit Log" when metrics are off).
  const features = [
    config.enableMetrics && 'Metrics',
    config.enableAuditLog && 'Audit Log',
  ].filter(Boolean);
  const featureSummary = features.length > 0 ? features.join(', ') : 'none';
  // The metrics route is only registered when enabled, so don't advertise its URL otherwise.
  const metricsUrl = config.enableMetrics
    ? `http://localhost:${config.port}${config.metricsPath}`
    : 'disabled';
  app.listen(config.port, '0.0.0.0', () => {
    console.log(`
╔═══════════════════════════════════════════════════════╗
║ Memento MCP Server v3.2.0 (Enhanced) ║
║ Streamable HTTP Transport ║
╚═══════════════════════════════════════════════════════╝
Server: http://localhost:${config.port}
MCP: http://localhost:${config.port}/mcp
Health: http://localhost:${config.port}${config.healthPath}
Metrics: ${metricsUrl}
Auth: ${config.requireAuth ? 'ENABLED' : 'DISABLED (dev)'}
Timeout: ${config.requestTimeout}ms
Database: ${isPostgres ? 'PostgreSQL' : 'SQLite'}
Tools: 22
Features: ${featureSummary}
`);
  });
}
main().catch((error) => {
  console.error('Server error:', error);
  process.exit(1);
});
// ═══════════════════════════════════════════════════════════════
// Shutdown Handler
// ═══════════════════════════════════════════════════════════════
let isShuttingDown = false;
/**
 * Graceful shutdown on SIGINT/SIGTERM: close every open MCP transport, then disconnect
 * Prisma, then exit. Signals arriving while a shutdown is in progress are ignored.
 * Exits with code 1 if the Prisma disconnect fails, otherwise 0.
 */
async function shutdown() {
  if (isShuttingDown) return;
  isShuttingDown = true;
  log('info', 'Shutting down...');
  // Close transports before the DB goes away. The async wrapper turns both synchronous
  // throws and rejected promises from close() into settled results; closing is
  // best-effort, so failures are only logged.
  const closeResults = await Promise.allSettled(
    Object.values(transports).map(async (transport) => transport.close())
  );
  for (const result of closeResults) {
    if (result.status === 'rejected') {
      log('debug', 'Ignoring transport close error during shutdown:', result.reason);
    }
  }
  let exitCode = 0;
  try {
    await prisma.$disconnect();
  } catch (error) {
    // Without this, a failed disconnect would reject shutdown() and the process would
    // keep running (the unhandledRejection handler below only logs).
    log('error', 'Prisma disconnect failed during shutdown:', error);
    exitCode = 1;
  }
  process.exit(exitCode);
}
process.on('SIGINT', shutdown);
process.on('SIGTERM', shutdown);
process.on('uncaughtException', (err) => {
  logError(log, err);
  process.exit(1);
});
process.on('unhandledRejection', (reason) => {
  log('error', 'Unhandled rejection:', reason);
  recordError(getErrorCategory(McpErrors.INTERNAL_ERROR.code), 'unhandled_rejection');
});