471 lines
14 KiB
TypeScript
471 lines
14 KiB
TypeScript
/**
|
|
* Semantic Search Service
|
|
*
|
|
* Unified hybrid search combining:
|
|
* 1. PostgreSQL full-text search (tsvector / tsquery) via GIN index
|
|
* 2. pgvector cosine-distance nearest-neighbor search via HNSW index
|
|
* 3. Reciprocal Rank Fusion (RRF) for final ranking
|
|
*
|
|
* All vector operations happen in the database — no JS cosine-similarity loops.
|
|
* All SQL uses parameterized queries ($1, $2…) via Prisma $queryRawUnsafe with bind params.
|
|
*/
|
|
|
|
import { embeddingService } from './embedding.service'
|
|
import { prisma } from '@/lib/prisma'
|
|
import { auth } from '@/auth'
|
|
|
|
/**
 * A single hit produced by SemanticSearchService.
 */
export interface SearchResult {
  /** Id of the matching note. */
  noteId: string

  /**
   * Note title. The search methods substitute a default title when the note
   * has none, but the type still admits `null`.
   */
  title: string | null

  /** Note content; document-chunk hits carry a truncated chunk excerpt instead. */
  content: string

  /**
   * Relevance score, higher is better. Its scale depends on the search path
   * (fused RRF score, or a constant for fallback/document hits), so compare
   * scores only within one result set.
   */
  score: number

  /** Whether the hit is classified as an exact or a related match. */
  matchType: 'exact' | 'related'

  /** Language of the note, when known. */
  language?: string | null
}
|
|
|
|
/**
 * Tuning knobs accepted by the SemanticSearchService search methods.
 */
export interface SearchOptions {
  /** Maximum number of results to return (service default: 20). */
  limit?: number

  /** Minimum cosine similarity a vector-search candidate must reach (service default: 0.3). */
  threshold?: number

  /** NOTE(review): not read anywhere in SemanticSearchService as written. */
  includeExactMatches?: boolean

  /**
   * Restrict note results to one notebook. An empty string selects notes
   * that belong to no notebook; leave undefined for no restriction.
   */
  notebookId?: string

  /** Title shown for notes that have none (service default: 'Untitled'). */
  defaultTitle?: string
}
|
|
|
|
/** Identifier charset accepted by assertSafeId (no `g` flag, so `test` is stateless). */
const SAFE_ID_PATTERN = /^[a-zA-Z0-9_-]+$/

/**
 * Guard for identifiers that end up in raw SQL.
 *
 * Accepts only non-empty strings made of ASCII letters, digits, `_` and `-`
 * (enough for CUID/UUID-style ids). Callers in this file additionally pass the
 * value as a bind parameter, so this is defence in depth.
 *
 * @param value - identifier to validate
 * @param label - name used in the error message
 * @throws Error with message `Invalid ${label} format` when `value` is not such a string
 */
function assertSafeId(value: string, label: string): void {
  // Types are erased at runtime: RegExp#test would coerce null/undefined to
  // "null"/"undefined", which match the pattern, so reject non-strings first.
  if (typeof value !== 'string' || !SAFE_ID_PATTERN.test(value)) {
    throw new Error(`Invalid ${label} format`)
  }
}
|
|
|
|
/** SearchOptions with defaults applied; internal to SemanticSearchService. */
interface ResolvedSearchOptions {
  limit: number
  threshold: number
  notebookId?: string
  defaultTitle: string
}

/** A note id with its 1-based position in a ranked candidate list. */
interface RankedNote {
  noteId: string
  rank: number
}

/** Row returned by the document-chunk similarity query in searchWithDocuments(). */
interface DocumentChunkRow {
  content: string | null
  // NOTE(review): assumed non-null; the query does not filter NULL page numbers.
  pageNumber: number
  fileName: string
  noteId: string
  noteTitle: string | null
}

/** searchWithDocuments() result: either a note hit or a document-chunk hit. */
type HybridSearchResult = SearchResult & {
  source?: 'note' | 'document'
  pageNumber?: number
  fileName?: string
}

/**
 * Hybrid note search plus embedding maintenance.
 *
 * Searching combines PostgreSQL full-text search (tsvector/tsquery) with
 * pgvector cosine distance and fuses both ranked lists with Reciprocal Rank
 * Fusion (RRF). Every search is scoped to exactly one user; without a user id
 * nothing is returned.
 *
 * SECURITY: raw SQL goes through Prisma's *Unsafe APIs, but every
 * caller-supplied value is a bind parameter ($1, $2, …) and ids are also
 * checked with assertSafeId(). Only hard-coded SQL fragments are interpolated.
 */
export class SemanticSearchService {
  /** RRF smoothing constant k in score = Σ 1 / (k + rank). */
  private readonly RRF_K = 60
  private readonly DEFAULT_LIMIT = 20
  /** Default minimum cosine similarity for vector candidates. */
  private readonly DEFAULT_THRESHOLD = 0.3
  private readonly DEFAULT_TITLE = 'Untitled'
  /** Candidate-list sizes fetched from each retriever before fusion. */
  private readonly VECTOR_CANDIDATES = 50
  private readonly FTS_CANDIDATES = 50
  private readonly DOCUMENT_CANDIDATES = 50
  /** Maximum number of characters of a document chunk copied into a result. */
  private readonly DOCUMENT_EXCERPT_LENGTH = 500

  /**
   * Hybrid search (FTS + pgvector, RRF-fused) over the notes of the currently
   * signed-in user, resolved via auth().
   *
   * Returns [] when the trimmed query is shorter than 2 characters or when there
   * is no signed-in user, so an anonymous request can never read other users'
   * notes. Server-side callers without an HTTP session should use searchAsUser().
   */
  async search(query: string, options: SearchOptions = {}): Promise<SearchResult[]> {
    if (!this.isSearchableQuery(query)) return []

    const session = await auth()
    const userId = session?.user?.id
    // Without a user id the SQL would have no owner filter: refuse instead.
    if (!userId) return []

    return this._doSearch(query, userId, this.resolveOptions(options))
  }

  /**
   * Hybrid search over the notes of `userId`, without calling auth().
   * Used by agent tools that run server-side without an HTTP session.
   * Returns [] for an empty `userId` or a trimmed query shorter than 2 characters.
   */
  async searchAsUser(
    userId: string,
    query: string,
    options: SearchOptions = {}
  ): Promise<SearchResult[]> {
    if (!userId || !this.isSearchableQuery(query)) return []
    return this._doSearch(query, userId, this.resolveOptions(options))
  }

  /** True when `query` is non-empty and at least 2 characters after trimming. */
  private isSearchableQuery(query: string): boolean {
    return Boolean(query) && query.trim().length >= 2
  }

  /** Apply the service defaults to caller-supplied options. */
  private resolveOptions(options: SearchOptions): ResolvedSearchOptions {
    return {
      limit: options.limit ?? this.DEFAULT_LIMIT,
      threshold: options.threshold ?? this.DEFAULT_THRESHOLD,
      notebookId: options.notebookId,
      defaultTitle: options.defaultTitle ?? this.DEFAULT_TITLE,
    }
  }

  /**
   * Run FTS and vector retrieval in parallel, fuse them with RRF and return the
   * top `opts.limit` hits. Falls back to FTS-only results if anything throws.
   */
  private async _doSearch(
    query: string,
    userId: string,
    opts: ResolvedSearchOptions
  ): Promise<SearchResult[]> {
    try {
      const [keywordResults, semanticResults] = await Promise.all([
        this.ftsSearch(query, userId, opts.notebookId),
        this.vectorSearch(query, userId, opts.threshold, opts.notebookId),
      ])

      // Notes that matched lexically are reported as 'exact'; notes found only
      // by vector similarity are 'related'. A cutoff on the fused score cannot
      // express this, because RRF scores never exceed 2 / (RRF_K + 1) ≈ 0.033.
      const keywordIds = new Set(keywordResults.map(r => r.noteId))
      const fused = await this.reciprocalRankFusion(keywordResults, semanticResults)

      return fused
        .sort((a, b) => b.score - a.score)
        .slice(0, opts.limit)
        .map(result => ({
          ...result,
          title: result.title || opts.defaultTitle,
          matchType: keywordIds.has(result.noteId) ? ('exact' as const) : ('related' as const),
        }))
    } catch (error) {
      console.error('Error in hybrid search:', error)
      return this._ftsFallback(query, userId, opts)
    }
  }

  /**
   * PostgreSQL full-text search using the "tsv" tsvector column (GIN index).
   * Returns up to FTS_CANDIDATES note ids ordered by ts_rank, with 1-based ranks.
   *
   * SECURITY: all values are bind parameters; ids are validated with assertSafeId().
   */
  private async ftsSearch(
    query: string,
    userId: string,
    notebookId?: string
  ): Promise<RankedNote[]> {
    assertSafeId(userId, 'userId')

    // $1 = search text (referenced twice), $2 = userId
    const params: unknown[] = [query, userId]
    const notebookClause = this.buildNotebookClause('"notebookId"', notebookId, params)
    params.push(this.FTS_CANDIDATES)
    const limitParam = `$${params.length}`

    const rows = await prisma.$queryRawUnsafe<Array<{ noteId: string; rank: number }>>(
      `SELECT id AS "noteId",
              ts_rank("tsv", plainto_tsquery('simple', $1)) AS rank
         FROM "Note"
        WHERE "tsv" @@ plainto_tsquery('simple', $1)
          AND "trashedAt" IS NULL
          AND "isArchived" = false
          AND "userId" = $2
          ${notebookClause}
        ORDER BY rank DESC
        LIMIT ${limitParam}`,
      ...params
    )

    return rows.map((row, index) => ({ noteId: row.noteId, rank: index + 1 }))
  }

  /**
   * pgvector cosine-distance search over "NoteEmbedding" (HNSW index).
   * Keeps candidates whose similarity (1 - distance) is >= `threshold` and
   * returns up to VECTOR_CANDIDATES of them with 1-based ranks.
   * Returns [] when the query embedding cannot be generated.
   *
   * SECURITY: all values are bind parameters; ids are validated with
   * assertSafeId(); the vector literal comes from the embedding model output.
   */
  private async vectorSearch(
    query: string,
    userId: string,
    threshold: number,
    notebookId?: string
  ): Promise<RankedNote[]> {
    // Validate before spending an embedding call on a request that would fail.
    assertSafeId(userId, 'userId')
    if (notebookId) assertSafeId(notebookId, 'notebookId')

    const queryEmbedding = await this.embedQuery(query)
    if (!queryEmbedding) return []

    // $1 = query vector (referenced three times), $2 = userId
    const params: unknown[] = [embeddingService.toVectorString(queryEmbedding), userId]
    const notebookClause = this.buildNotebookClause('n."notebookId"', notebookId, params)
    params.push(threshold)
    const thresholdParam = `$${params.length}`
    params.push(this.VECTOR_CANDIDATES)
    const limitParam = `$${params.length}`

    const rows = await prisma.$queryRawUnsafe<Array<{ noteId: string; similarity: number }>>(
      `SELECT n.id AS "noteId",
              1 - (e."embedding"::vector <=> $1::vector) AS similarity
         FROM "Note" n
        INNER JOIN "NoteEmbedding" e ON e."noteId" = n.id
        WHERE n."trashedAt" IS NULL
          AND n."isArchived" = false
          AND n."userId" = $2
          ${notebookClause}
          AND 1 - (e."embedding"::vector <=> $1::vector) >= ${thresholdParam}
        ORDER BY e."embedding"::vector <=> $1::vector ASC
        LIMIT ${limitParam}`,
      ...params
    )

    return rows.map((row, index) => ({ noteId: row.noteId, rank: index + 1 }))
  }

  /** Embed `query`; logs the error and returns null when the embedding model fails. */
  private async embedQuery(query: string): Promise<number[] | null> {
    try {
      const { embedding } = await embeddingService.generateEmbedding(query)
      return embedding
    } catch (error) {
      console.error('Failed to generate query embedding:', error)
      return null
    }
  }

  /**
   * Build the SQL fragment that scopes results to a notebook, appending any bind
   * value to `params` (so its placeholder is `$<params.length>`):
   *   - undefined → no restriction ('')
   *   - '' / null → notes without a notebook (IS NULL)
   *   - an id     → notes in that notebook
   * `column` must be a hard-coded, trusted column reference.
   */
  private buildNotebookClause(
    column: string,
    notebookId: string | undefined,
    params: unknown[]
  ): string {
    if (notebookId === undefined) return ''
    if (!notebookId) return `AND ${column} IS NULL`
    assertSafeId(notebookId, 'notebookId')
    params.push(notebookId)
    return `AND ${column} = $${params.length}`
  }

  /**
   * Reciprocal Rank Fusion: each list contributes 1 / (RRF_K + rank) for every
   * note it contains; contributions are summed per note. Returns the non-trashed
   * notes with their fused score (unsorted; the caller sorts).
   */
  private async reciprocalRankFusion(
    keywordResults: RankedNote[],
    semanticResults: RankedNote[]
  ): Promise<SearchResult[]> {
    const scores = new Map<string, number>()
    for (const { noteId, rank } of [...keywordResults, ...semanticResults]) {
      scores.set(noteId, (scores.get(noteId) ?? 0) + 1 / (this.RRF_K + rank))
    }
    if (scores.size === 0) return []

    const notes = await this.fetchNotes(Array.from(scores.keys()))
    return notes.map(note => ({
      noteId: note.id,
      title: note.title,
      content: note.content,
      score: scores.get(note.id) ?? 0,
      matchType: 'related' as const,
      language: note.language,
    }))
  }

  /** Load display fields for the given note ids, skipping trashed notes. Result order is unspecified. */
  private fetchNotes(noteIds: string[]) {
    return prisma.note.findMany({
      where: { id: { in: noteIds }, trashedAt: null },
      select: { id: true, title: true, content: true, language: true },
    })
  }

  /**
   * FTS-only fallback used when hybrid search throws. Results keep the FTS rank
   * order, get a constant score of 1.0 and are all lexical ('exact') matches.
   * Returns [] (after logging) if the fallback itself fails.
   */
  private async _ftsFallback(
    query: string,
    userId: string,
    opts: ResolvedSearchOptions
  ): Promise<SearchResult[]> {
    try {
      const ranked = (await this.ftsSearch(query, userId, opts.notebookId)).slice(0, opts.limit)
      if (ranked.length === 0) return []

      const notes = await this.fetchNotes(ranked.map(r => r.noteId))
      // findMany does not preserve the id order, so restore the FTS ranking.
      const notesById = new Map(notes.map(note => [note.id, note] as const))

      const results: SearchResult[] = []
      for (const { noteId } of ranked) {
        const note = notesById.get(noteId)
        if (!note) continue
        results.push({
          noteId: note.id,
          title: note.title || opts.defaultTitle,
          content: note.content,
          score: 1.0,
          matchType: 'exact',
          language: note.language,
        })
      }
      return results
    } catch (error) {
      console.error('FTS fallback search failed:', error)
      return []
    }
  }

  /**
   * Generate (or refresh) the embedding of one note, upsert it into
   * "NoteEmbedding" as a pgvector value, then stamp `lastAiAnalysis`.
   * Skips notes with blank content and notes for which
   * embeddingService.shouldRegenerateEmbedding() returns false.
   *
   * SECURITY: values are bind parameters ($1, $2); noteId is validated with assertSafeId().
   * @throws rethrows (after logging) any validation, model or database error
   */
  async indexNote(noteId: string, options?: { force?: boolean }): Promise<void> {
    try {
      assertSafeId(noteId, 'noteId')

      const note = await prisma.note.findUnique({
        where: { id: noteId },
        select: { content: true, title: true, lastAiAnalysis: true, sourceUrl: true },
      })
      if (!note?.content?.trim()) return

      const shouldRegenerate = embeddingService.shouldRegenerateEmbedding(
        note.content,
        null,
        note.lastAiAnalysis,
        { force: options?.force, isClip: Boolean(note.sourceUrl?.trim()) },
      )
      if (!shouldRegenerate) return

      const { embedding } = await embeddingService.generateNoteEmbedding(note.title, note.content)

      // $executeRawUnsafe is Prisma's API for statements that return no rows.
      await prisma.$executeRawUnsafe(
        `INSERT INTO "NoteEmbedding" ("id", "noteId", "embedding", "createdAt")
         VALUES (gen_random_uuid(), $1, $2::vector, now())
         ON CONFLICT ("noteId")
         DO UPDATE SET "embedding" = $2::vector`,
        noteId,
        embeddingService.toVectorString(embedding)
      )

      await prisma.note.update({
        where: { id: noteId },
        data: { lastAiAnalysis: new Date() },
      })
    } catch (error) {
      console.error(`Error indexing note ${noteId}:`, error)
      throw error
    }
  }

  /**
   * Index notes in sequential batches of 20, each batch running concurrently.
   * A failing note does not stop the run: indexNote() logs its error and
   * Promise.allSettled absorbs the rejection.
   */
  async indexBatchNotes(noteIds: string[]): Promise<void> {
    const BATCH_SIZE = 20

    for (let start = 0; start < noteIds.length; start += BATCH_SIZE) {
      const batch = noteIds.slice(start, start + BATCH_SIZE)
      await Promise.allSettled(batch.map(noteId => this.indexNote(noteId)))
    }
  }

  /**
   * Search the user's notes (via searchAsUser) and, unless
   * `includeDocuments === false`, the embedded chunks of their ready
   * attachments. The two lists are merged by per-list RRF position
   * (1 / (RRF_K + position)) and cut to `limit`.
   *
   * Scope: `noteId` restricts chunks to one note's attachments; otherwise
   * `notebookId` applies with the same semantics as the note search.
   * If the query embedding cannot be generated, only note results are returned.
   */
  async searchWithDocuments(
    userId: string,
    query: string,
    options?: SearchOptions & { noteId?: string; includeDocuments?: boolean }
  ): Promise<HybridSearchResult[]> {
    const noteResults = await this.searchAsUser(userId, query, options)

    // searchAsUser() already rejects these inputs; skip the document query too.
    if (options?.includeDocuments === false || !userId || !this.isSearchableQuery(query)) {
      return noteResults
    }

    const queryEmbedding = await this.embedQuery(query)
    if (!queryEmbedding) return noteResults

    const params: unknown[] = [
      embeddingService.toVectorString(queryEmbedding), // $1
      this.DOCUMENT_CANDIDATES, // $2
      userId, // $3
    ]
    let scopeClause: string
    if (options?.noteId) {
      assertSafeId(options.noteId, 'noteId')
      params.push(options.noteId)
      scopeClause = `AND na."noteId" = $${params.length}`
    } else {
      scopeClause = this.buildNotebookClause('n."notebookId"', options?.notebookId, params)
    }

    // NOTE(review): unlike the note search, archived notes are not excluded here.
    const chunks = await prisma.$queryRawUnsafe<DocumentChunkRow[]>(
      `SELECT dc.content,
              dc."pageNumber",
              na."fileName",
              na."noteId",
              n.title AS "noteTitle"
         FROM "DocumentChunk" dc
         JOIN "NoteAttachment" na ON na.id = dc."attachmentId"
         JOIN "Note" n ON n.id = na."noteId"
        WHERE dc."embedding" IS NOT NULL
          AND na.status = 'ready'
          AND n."trashedAt" IS NULL
          AND n."userId" = $3
          ${scopeClause}
        ORDER BY dc."embedding"::vector <=> $1::vector
        LIMIT $2`,
      ...params
    )

    const defaultTitle = options?.defaultTitle ?? this.DEFAULT_TITLE
    const rrfScore = (position: number): number => 1 / (this.RRF_K + position + 1)

    const fused = [
      ...noteResults.map((result, position) => ({
        ...result,
        source: 'note' as const,
        rrfScore: rrfScore(position),
      })),
      ...chunks.map((chunk, position) => ({
        noteId: chunk.noteId,
        title: `${chunk.noteTitle || defaultTitle} → ${chunk.fileName} (p.${chunk.pageNumber})`,
        content: (chunk.content ?? '').substring(0, this.DOCUMENT_EXCERPT_LENGTH),
        score: 0.5,
        matchType: 'related' as const,
        source: 'document' as const,
        pageNumber: chunk.pageNumber,
        fileName: chunk.fileName,
        rrfScore: rrfScore(position),
      })),
    ]

    // Stable sort: on equal positions, note hits stay ahead of document hits.
    return fused
      .sort((a, b) => b.rrfScore - a.rrfScore)
      .slice(0, options?.limit ?? this.DEFAULT_LIMIT)
  }
}
|
|
|
|
/** Module-level shared instance of SemanticSearchService. */
export const semanticSearchService: SemanticSearchService = new SemanticSearchService()
|