security: fix SQL injection in semantic search - use parameterized queries with bind params
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 5s
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 5s
- Replace string interpolation in $queryRawUnsafe with bind params ($1, $2...) - Add assertSafeId() validation for userId, notebookId, noteId - ftsSearch: bind query, userId, notebookId as parameters - vectorSearch: bind vector string, userId, notebookId, threshold as parameters - indexNote: already used bind params, added noteId validation - Fixes CRITICAL security audit finding #1
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
* 3. Reciprocal Rank Fusion (RRF) for final ranking
|
||||
*
|
||||
* All vector operations happen in the database — no JS cosine-similarity loops.
|
||||
* All SQL uses parameterized queries ($1, $2…) via Prisma $queryRawUnsafe with bind params.
|
||||
*/
|
||||
|
||||
import { embeddingService } from './embedding.service'
|
||||
@@ -30,6 +31,13 @@ export interface SearchOptions {
|
||||
defaultTitle?: string
|
||||
}
|
||||
|
||||
/** Validate that a value looks like a CUID/UUID (alphanumeric + dashes). */
|
||||
function assertSafeId(value: string, label: string): void {
|
||||
if (!/^[a-zA-Z0-9_-]+$/.test(value)) {
|
||||
throw new Error(`Invalid ${label} format`)
|
||||
}
|
||||
}
|
||||
|
||||
export class SemanticSearchService {
|
||||
private readonly RRF_K = 60
|
||||
private readonly DEFAULT_LIMIT = 20
|
||||
@@ -108,35 +116,60 @@ export class SemanticSearchService {
|
||||
|
||||
/**
|
||||
* PostgreSQL full-text search using tsvector + GIN index.
|
||||
* Returns ranked results using ts_rank.
|
||||
* SECURITY: Uses $queryRawUnsafe with parameterized bind params ($1, $2…).
|
||||
* All IDs are validated via assertSafeId() before inclusion.
|
||||
*/
|
||||
private async ftsSearch(
|
||||
query: string,
|
||||
userId: string | null,
|
||||
notebookId?: string
|
||||
): Promise<Array<{ noteId: string; rank: number }>> {
|
||||
const safeQuery = query.replace(/'/g, "''")
|
||||
// Validate IDs before any SQL construction
|
||||
if (userId) assertSafeId(userId, 'userId')
|
||||
if (notebookId) assertSafeId(notebookId, 'notebookId')
|
||||
|
||||
const userClause = userId ? `AND "userId" = '${userId}'` : ''
|
||||
const notebookClause = notebookId !== undefined
|
||||
? `AND "notebookId" ${notebookId ? `= '${notebookId.replace(/'/g, "''")}'` : 'IS NULL'}`
|
||||
: ''
|
||||
const params: any[] = []
|
||||
let paramIdx = 0
|
||||
|
||||
const sql = `
|
||||
SELECT id AS "noteId", ts_rank("tsv", plainto_tsquery('simple', '${safeQuery}')) AS rank
|
||||
// Bind search query (used twice in the query)
|
||||
params.push(query)
|
||||
const queryParam1 = `$${++paramIdx}`
|
||||
const queryParam2 = `$${++paramIdx}`
|
||||
params.push(query) // second usage
|
||||
|
||||
// User filter
|
||||
const userClause = userId ? `AND "userId" = $${++paramIdx}` : ''
|
||||
if (userId) params.push(userId)
|
||||
|
||||
// Notebook filter
|
||||
let notebookClause = ''
|
||||
if (notebookId !== undefined) {
|
||||
if (notebookId) {
|
||||
notebookClause = `AND "notebookId" = $${++paramIdx}`
|
||||
params.push(notebookId)
|
||||
} else {
|
||||
notebookClause = `AND "notebookId" IS NULL`
|
||||
}
|
||||
}
|
||||
|
||||
// Limit
|
||||
params.push(this.FTS_CANDIDATES)
|
||||
const limitParam = `$${++paramIdx}`
|
||||
|
||||
const rows: Array<{ noteId: string; rank: number }> = await prisma.$queryRawUnsafe(
|
||||
`SELECT id AS "noteId",
|
||||
ts_rank("tsv", plainto_tsquery('simple', ${queryParam1})) AS rank
|
||||
FROM "Note"
|
||||
WHERE "tsv" @@ plainto_tsquery('simple', '${safeQuery}')
|
||||
WHERE "tsv" @@ plainto_tsquery('simple', ${queryParam2})
|
||||
AND "trashedAt" IS NULL
|
||||
AND "isArchived" = false
|
||||
${userClause}
|
||||
${notebookClause}
|
||||
ORDER BY rank DESC
|
||||
LIMIT ${this.FTS_CANDIDATES}
|
||||
`
|
||||
LIMIT ${limitParam}`,
|
||||
...params
|
||||
)
|
||||
|
||||
const rows: Array<{ noteId: string; rank: number }> = await prisma.$queryRawUnsafe(sql)
|
||||
|
||||
const maxRank = rows.length > 0 ? rows[0].rank : 1
|
||||
return rows.map((r, i) => ({
|
||||
noteId: r.noteId,
|
||||
rank: i + 1
|
||||
@@ -145,7 +178,9 @@ export class SemanticSearchService {
|
||||
|
||||
/**
|
||||
* pgvector cosine-distance search using the HNSW index.
|
||||
* Returns nearest neighbors above the similarity threshold.
|
||||
* SECURITY: Uses $queryRawUnsafe with parameterized bind params ($1, $2…).
|
||||
* userId/notebookId validated via assertSafeId().
|
||||
* vecStr is internally generated from the embedding model output.
|
||||
*/
|
||||
private async vectorSearch(
|
||||
query: string,
|
||||
@@ -162,27 +197,56 @@ export class SemanticSearchService {
|
||||
return []
|
||||
}
|
||||
|
||||
const vecStr = embeddingService.toVectorString(queryEmbedding)
|
||||
const userClause = userId ? `AND n."userId" = '${userId}'` : ''
|
||||
const notebookClause = notebookId !== undefined
|
||||
? `AND n."notebookId" ${notebookId ? `= '${notebookId.replace(/'/g, "''")}'` : 'IS NULL'}`
|
||||
: ''
|
||||
// Validate IDs
|
||||
if (userId) assertSafeId(userId, 'userId')
|
||||
if (notebookId) assertSafeId(notebookId, 'notebookId')
|
||||
|
||||
const sql = `
|
||||
SELECT n.id AS "noteId",
|
||||
1 - (e."embedding" <=> '${vecStr}'::vector) AS similarity
|
||||
const vecStr = embeddingService.toVectorString(queryEmbedding)
|
||||
|
||||
const params: any[] = []
|
||||
let paramIdx = 0
|
||||
|
||||
// Vector parameter (used multiple times in the query)
|
||||
params.push(vecStr)
|
||||
const vecParam = `$${++paramIdx}::vector`
|
||||
|
||||
// User filter
|
||||
const userClause = userId ? `AND n."userId" = $${++paramIdx}` : ''
|
||||
if (userId) params.push(userId)
|
||||
|
||||
// Notebook filter
|
||||
let notebookClause = ''
|
||||
if (notebookId !== undefined) {
|
||||
if (notebookId) {
|
||||
notebookClause = `AND n."notebookId" = $${++paramIdx}`
|
||||
params.push(notebookId)
|
||||
} else {
|
||||
notebookClause = `AND n."notebookId" IS NULL`
|
||||
}
|
||||
}
|
||||
|
||||
// Threshold
|
||||
params.push(threshold)
|
||||
const thresholdParam = `$${++paramIdx}`
|
||||
|
||||
// Limit
|
||||
params.push(this.VECTOR_CANDIDATES)
|
||||
const limitParam = `$${++paramIdx}`
|
||||
|
||||
const rows: Array<{ noteId: string; similarity: number }> = await prisma.$queryRawUnsafe(
|
||||
`SELECT n.id AS "noteId",
|
||||
1 - (e."embedding" <=> ${vecParam}) AS similarity
|
||||
FROM "Note" n
|
||||
INNER JOIN "NoteEmbedding" e ON e."noteId" = n.id
|
||||
WHERE n."trashedAt" IS NULL
|
||||
AND n."isArchived" = false
|
||||
${userClause}
|
||||
${notebookClause}
|
||||
AND 1 - (e."embedding" <=> '${vecStr}'::vector) >= ${threshold}
|
||||
ORDER BY e."embedding" <=> '${vecStr}'::vector ASC
|
||||
LIMIT ${this.VECTOR_CANDIDATES}
|
||||
`
|
||||
|
||||
const rows: Array<{ noteId: string; similarity: number }> = await prisma.$queryRawUnsafe(sql)
|
||||
AND 1 - (e."embedding" <=> ${vecParam}) >= ${thresholdParam}
|
||||
ORDER BY e."embedding" <=> ${vecParam} ASC
|
||||
LIMIT ${limitParam}`,
|
||||
...params
|
||||
)
|
||||
|
||||
return rows.map((r, i) => ({
|
||||
noteId: r.noteId,
|
||||
@@ -265,9 +329,14 @@ export class SemanticSearchService {
|
||||
/**
|
||||
* Generate or update embedding for a note.
|
||||
* Stores as native pgvector via raw SQL.
|
||||
*
|
||||
* SECURITY: Uses parameterized bind params ($1, $2).
|
||||
* noteId validated via assertSafeId().
|
||||
*/
|
||||
async indexNote(noteId: string): Promise<void> {
|
||||
try {
|
||||
assertSafeId(noteId, 'noteId')
|
||||
|
||||
const note = await prisma.note.findUnique({
|
||||
where: { id: noteId },
|
||||
select: { content: true, lastAiAnalysis: true }
|
||||
@@ -286,7 +355,7 @@ export class SemanticSearchService {
|
||||
const { embedding } = await embeddingService.generateEmbedding(note.content)
|
||||
const vecStr = embeddingService.toVectorString(embedding)
|
||||
|
||||
await prisma.$executeRawUnsafe(
|
||||
await prisma.$queryRawUnsafe(
|
||||
`INSERT INTO "NoteEmbedding" ("id", "noteId", "embedding", "createdAt", "updatedAt")
|
||||
VALUES (gen_random_uuid(), $1, $2::vector, now(), now())
|
||||
ON CONFLICT ("noteId")
|
||||
|
||||
Reference in New Issue
Block a user