security: fix SQL injection in semantic search - use parameterized queries with bind params
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 5s
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 5s
- Replace string interpolation in $queryRawUnsafe with bind params ($1, $2...)
- Add assertSafeId() validation for userId, notebookId, noteId
- ftsSearch: bind query, userId, notebookId as parameters
- vectorSearch: bind vector string, userId, notebookId, threshold as parameters
- indexNote: already used bind params, added noteId validation
- Fixes CRITICAL security audit finding #1
This commit is contained in:
@@ -7,6 +7,7 @@
|
|||||||
* 3. Reciprocal Rank Fusion (RRF) for final ranking
|
* 3. Reciprocal Rank Fusion (RRF) for final ranking
|
||||||
*
|
*
|
||||||
* All vector operations happen in the database — no JS cosine-similarity loops.
|
* All vector operations happen in the database — no JS cosine-similarity loops.
|
||||||
|
* All SQL uses parameterized queries ($1, $2…) via Prisma $queryRawUnsafe with bind params.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { embeddingService } from './embedding.service'
|
import { embeddingService } from './embedding.service'
|
||||||
@@ -30,6 +31,13 @@ export interface SearchOptions {
|
|||||||
defaultTitle?: string
|
defaultTitle?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Defense-in-depth guard for identifiers that are passed to raw SQL calls.
 *
 * Accepts only non-empty strings made up of ASCII letters, digits,
 * underscores and dashes, which covers CUID- and UUID-style ids.
 *
 * @param value - Identifier to validate.
 * @param label - Name of the identifier, used in the error message (e.g. `userId`).
 * @throws Error with message `Invalid <label> format` when `value` is not a
 *   string, is empty, or contains any other character.
 */
function assertSafeId(value: string, label: string): void {
  // The static type is `string`, but callers may forward untyped runtime data.
  // RegExp.test() coerces its argument to a string, so without the typeof
  // check a number or a one-element array would slip through validation.
  if (typeof value !== 'string' || !/^[a-zA-Z0-9_-]+$/.test(value)) {
    throw new Error(`Invalid ${label} format`)
  }
}
|
||||||
|
|
||||||
export class SemanticSearchService {
|
export class SemanticSearchService {
|
||||||
private readonly RRF_K = 60
|
private readonly RRF_K = 60
|
||||||
private readonly DEFAULT_LIMIT = 20
|
private readonly DEFAULT_LIMIT = 20
|
||||||
@@ -108,35 +116,60 @@ export class SemanticSearchService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* PostgreSQL full-text search using tsvector + GIN index.
|
* PostgreSQL full-text search using tsvector + GIN index.
|
||||||
* Returns ranked results using ts_rank.
|
* SECURITY: Uses $queryRawUnsafe with parameterized bind params ($1, $2…).
|
||||||
|
* All IDs are validated via assertSafeId() before being bound as query parameters.
|
||||||
*/
|
*/
|
||||||
private async ftsSearch(
|
private async ftsSearch(
|
||||||
query: string,
|
query: string,
|
||||||
userId: string | null,
|
userId: string | null,
|
||||||
notebookId?: string
|
notebookId?: string
|
||||||
): Promise<Array<{ noteId: string; rank: number }>> {
|
): Promise<Array<{ noteId: string; rank: number }>> {
|
||||||
const safeQuery = query.replace(/'/g, "''")
|
// Validate IDs before any SQL construction
|
||||||
|
if (userId) assertSafeId(userId, 'userId')
|
||||||
|
if (notebookId) assertSafeId(notebookId, 'notebookId')
|
||||||
|
|
||||||
const userClause = userId ? `AND "userId" = '${userId}'` : ''
|
const params: any[] = []
|
||||||
const notebookClause = notebookId !== undefined
|
let paramIdx = 0
|
||||||
? `AND "notebookId" ${notebookId ? `= '${notebookId.replace(/'/g, "''")}'` : 'IS NULL'}`
|
|
||||||
: ''
|
|
||||||
|
|
||||||
const sql = `
|
// Bind search query (used twice in the query)
|
||||||
SELECT id AS "noteId", ts_rank("tsv", plainto_tsquery('simple', '${safeQuery}')) AS rank
|
params.push(query)
|
||||||
FROM "Note"
|
const queryParam1 = `$${++paramIdx}`
|
||||||
WHERE "tsv" @@ plainto_tsquery('simple', '${safeQuery}')
|
const queryParam2 = `$${++paramIdx}`
|
||||||
AND "trashedAt" IS NULL
|
params.push(query) // second usage
|
||||||
AND "isArchived" = false
|
|
||||||
${userClause}
|
|
||||||
${notebookClause}
|
|
||||||
ORDER BY rank DESC
|
|
||||||
LIMIT ${this.FTS_CANDIDATES}
|
|
||||||
`
|
|
||||||
|
|
||||||
const rows: Array<{ noteId: string; rank: number }> = await prisma.$queryRawUnsafe(sql)
|
// User filter
|
||||||
|
const userClause = userId ? `AND "userId" = $${++paramIdx}` : ''
|
||||||
|
if (userId) params.push(userId)
|
||||||
|
|
||||||
|
// Notebook filter
|
||||||
|
let notebookClause = ''
|
||||||
|
if (notebookId !== undefined) {
|
||||||
|
if (notebookId) {
|
||||||
|
notebookClause = `AND "notebookId" = $${++paramIdx}`
|
||||||
|
params.push(notebookId)
|
||||||
|
} else {
|
||||||
|
notebookClause = `AND "notebookId" IS NULL`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit
|
||||||
|
params.push(this.FTS_CANDIDATES)
|
||||||
|
const limitParam = `$${++paramIdx}`
|
||||||
|
|
||||||
|
const rows: Array<{ noteId: string; rank: number }> = await prisma.$queryRawUnsafe(
|
||||||
|
`SELECT id AS "noteId",
|
||||||
|
ts_rank("tsv", plainto_tsquery('simple', ${queryParam1})) AS rank
|
||||||
|
FROM "Note"
|
||||||
|
WHERE "tsv" @@ plainto_tsquery('simple', ${queryParam2})
|
||||||
|
AND "trashedAt" IS NULL
|
||||||
|
AND "isArchived" = false
|
||||||
|
${userClause}
|
||||||
|
${notebookClause}
|
||||||
|
ORDER BY rank DESC
|
||||||
|
LIMIT ${limitParam}`,
|
||||||
|
...params
|
||||||
|
)
|
||||||
|
|
||||||
const maxRank = rows.length > 0 ? rows[0].rank : 1
|
|
||||||
return rows.map((r, i) => ({
|
return rows.map((r, i) => ({
|
||||||
noteId: r.noteId,
|
noteId: r.noteId,
|
||||||
rank: i + 1
|
rank: i + 1
|
||||||
@@ -145,7 +178,9 @@ export class SemanticSearchService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* pgvector cosine-distance search using the HNSW index.
|
* pgvector cosine-distance search using the HNSW index.
|
||||||
* Returns nearest neighbors above the similarity threshold.
|
* SECURITY: Uses $queryRawUnsafe with parameterized bind params ($1, $2…).
|
||||||
|
* userId/notebookId validated via assertSafeId().
|
||||||
|
* vecStr is internally generated from the embedding model output.
|
||||||
*/
|
*/
|
||||||
private async vectorSearch(
|
private async vectorSearch(
|
||||||
query: string,
|
query: string,
|
||||||
@@ -162,27 +197,56 @@ export class SemanticSearchService {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate IDs
|
||||||
|
if (userId) assertSafeId(userId, 'userId')
|
||||||
|
if (notebookId) assertSafeId(notebookId, 'notebookId')
|
||||||
|
|
||||||
const vecStr = embeddingService.toVectorString(queryEmbedding)
|
const vecStr = embeddingService.toVectorString(queryEmbedding)
|
||||||
const userClause = userId ? `AND n."userId" = '${userId}'` : ''
|
|
||||||
const notebookClause = notebookId !== undefined
|
|
||||||
? `AND n."notebookId" ${notebookId ? `= '${notebookId.replace(/'/g, "''")}'` : 'IS NULL'}`
|
|
||||||
: ''
|
|
||||||
|
|
||||||
const sql = `
|
const params: any[] = []
|
||||||
SELECT n.id AS "noteId",
|
let paramIdx = 0
|
||||||
1 - (e."embedding" <=> '${vecStr}'::vector) AS similarity
|
|
||||||
FROM "Note" n
|
|
||||||
INNER JOIN "NoteEmbedding" e ON e."noteId" = n.id
|
|
||||||
WHERE n."trashedAt" IS NULL
|
|
||||||
AND n."isArchived" = false
|
|
||||||
${userClause}
|
|
||||||
${notebookClause}
|
|
||||||
AND 1 - (e."embedding" <=> '${vecStr}'::vector) >= ${threshold}
|
|
||||||
ORDER BY e."embedding" <=> '${vecStr}'::vector ASC
|
|
||||||
LIMIT ${this.VECTOR_CANDIDATES}
|
|
||||||
`
|
|
||||||
|
|
||||||
const rows: Array<{ noteId: string; similarity: number }> = await prisma.$queryRawUnsafe(sql)
|
// Vector parameter (used multiple times in the query)
|
||||||
|
params.push(vecStr)
|
||||||
|
const vecParam = `$${++paramIdx}::vector`
|
||||||
|
|
||||||
|
// User filter
|
||||||
|
const userClause = userId ? `AND n."userId" = $${++paramIdx}` : ''
|
||||||
|
if (userId) params.push(userId)
|
||||||
|
|
||||||
|
// Notebook filter
|
||||||
|
let notebookClause = ''
|
||||||
|
if (notebookId !== undefined) {
|
||||||
|
if (notebookId) {
|
||||||
|
notebookClause = `AND n."notebookId" = $${++paramIdx}`
|
||||||
|
params.push(notebookId)
|
||||||
|
} else {
|
||||||
|
notebookClause = `AND n."notebookId" IS NULL`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Threshold
|
||||||
|
params.push(threshold)
|
||||||
|
const thresholdParam = `$${++paramIdx}`
|
||||||
|
|
||||||
|
// Limit
|
||||||
|
params.push(this.VECTOR_CANDIDATES)
|
||||||
|
const limitParam = `$${++paramIdx}`
|
||||||
|
|
||||||
|
const rows: Array<{ noteId: string; similarity: number }> = await prisma.$queryRawUnsafe(
|
||||||
|
`SELECT n.id AS "noteId",
|
||||||
|
1 - (e."embedding" <=> ${vecParam}) AS similarity
|
||||||
|
FROM "Note" n
|
||||||
|
INNER JOIN "NoteEmbedding" e ON e."noteId" = n.id
|
||||||
|
WHERE n."trashedAt" IS NULL
|
||||||
|
AND n."isArchived" = false
|
||||||
|
${userClause}
|
||||||
|
${notebookClause}
|
||||||
|
AND 1 - (e."embedding" <=> ${vecParam}) >= ${thresholdParam}
|
||||||
|
ORDER BY e."embedding" <=> ${vecParam} ASC
|
||||||
|
LIMIT ${limitParam}`,
|
||||||
|
...params
|
||||||
|
)
|
||||||
|
|
||||||
return rows.map((r, i) => ({
|
return rows.map((r, i) => ({
|
||||||
noteId: r.noteId,
|
noteId: r.noteId,
|
||||||
@@ -265,9 +329,14 @@ export class SemanticSearchService {
|
|||||||
/**
|
/**
|
||||||
* Generate or update embedding for a note.
|
* Generate or update embedding for a note.
|
||||||
* Stores as native pgvector via raw SQL.
|
* Stores as native pgvector via raw SQL.
|
||||||
|
*
|
||||||
|
* SECURITY: Uses parameterized bind params ($1, $2).
|
||||||
|
* noteId validated via assertSafeId().
|
||||||
*/
|
*/
|
||||||
async indexNote(noteId: string): Promise<void> {
|
async indexNote(noteId: string): Promise<void> {
|
||||||
try {
|
try {
|
||||||
|
assertSafeId(noteId, 'noteId')
|
||||||
|
|
||||||
const note = await prisma.note.findUnique({
|
const note = await prisma.note.findUnique({
|
||||||
where: { id: noteId },
|
where: { id: noteId },
|
||||||
select: { content: true, lastAiAnalysis: true }
|
select: { content: true, lastAiAnalysis: true }
|
||||||
@@ -286,7 +355,7 @@ export class SemanticSearchService {
|
|||||||
const { embedding } = await embeddingService.generateEmbedding(note.content)
|
const { embedding } = await embeddingService.generateEmbedding(note.content)
|
||||||
const vecStr = embeddingService.toVectorString(embedding)
|
const vecStr = embeddingService.toVectorString(embedding)
|
||||||
|
|
||||||
await prisma.$executeRawUnsafe(
|
await prisma.$queryRawUnsafe(
|
||||||
`INSERT INTO "NoteEmbedding" ("id", "noteId", "embedding", "createdAt", "updatedAt")
|
`INSERT INTO "NoteEmbedding" ("id", "noteId", "embedding", "createdAt", "updatedAt")
|
||||||
VALUES (gen_random_uuid(), $1, $2::vector, now(), now())
|
VALUES (gen_random_uuid(), $1, $2::vector, now(), now())
|
||||||
ON CONFLICT ("noteId")
|
ON CONFLICT ("noteId")
|
||||||
|
|||||||
Reference in New Issue
Block a user