Files
Momento/memento-note/lib/ai/services/semantic-search.service.ts
Antigravity ff664f7523
All checks were successful
Deploy to Production / Build and Deploy (push) Successful in 5s
fix: add missing await on reciprocalRankFusion call
2026-05-12 10:53:34 +00:00

322 lines
9.4 KiB
TypeScript

/**
* Semantic Search Service
*
* Unified hybrid search combining:
* 1. PostgreSQL full-text search (tsvector / tsquery) via GIN index
* 2. pgvector cosine-distance nearest-neighbor search via HNSW index
* 3. Reciprocal Rank Fusion (RRF) for final ranking
*
* All vector operations happen in the database — no JS cosine-similarity loops.
*/
import { embeddingService } from './embedding.service'
import { prisma } from '@/lib/prisma'
import { auth } from '@/auth'
/** A single note returned by the search service, with its relevance data. */
export interface SearchResult {
/** Identifier of the matching note. */
noteId: string
/** Note title; `null` when the stored note has no title. */
title: string | null
/** Content of the note as stored. */
content: string
/** Relevance score; results are ordered with higher scores first. */
score: number
/** Classification of the match assigned by the search service. */
matchType: 'exact' | 'related'
/** Language recorded on the note, if any. */
language?: string | null
}
/** Optional tuning parameters accepted by the public search methods. */
export interface SearchOptions {
/** Maximum number of results to return. */
limit?: number
/** Minimum similarity a note must reach to be kept by the vector search. */
threshold?: number
/** NOTE(review): not read by the service code in this file — confirm whether it is still needed. */
includeExactMatches?: boolean
/** Restricts the search to one notebook; when omitted, no notebook filter is applied. */
notebookId?: string
/** Title substituted for notes whose own title is empty. */
defaultTitle?: string
}
/** Search options after defaults have been applied. */
type ResolvedSearchOptions = {
  limit: number
  threshold: number
  notebookId?: string
  defaultTitle: string
}

/**
 * Hybrid note search: PostgreSQL full-text search and pgvector similarity
 * search are run side by side and merged with Reciprocal Rank Fusion (RRF).
 *
 * Every raw SQL statement binds caller-supplied values as positional
 * parameters (`$1`, `$2`, …); nothing user-controlled is concatenated into
 * the SQL text.
 */
export class SemanticSearchService {
  /** RRF smoothing constant: each list contributes `1 / (RRF_K + rank)`. */
  private readonly RRF_K = 60
  /** Result count used when the caller gives no `limit`. */
  private readonly DEFAULT_LIMIT = 20
  /** Minimum similarity used when the caller gives no (or an invalid) `threshold`. */
  private readonly DEFAULT_THRESHOLD = 0.3
  /** Maximum number of candidates taken from the vector search. */
  private readonly VECTOR_CANDIDATES = 50
  /** Maximum number of candidates taken from the full-text search. */
  private readonly FTS_CANDIDATES = 50

  /**
   * Hybrid search on behalf of the currently authenticated user.
   *
   * @param query   Free-text query; queries shorter than 2 characters (after trimming) yield `[]`.
   * @param options Optional limit / threshold / notebook filter / default title.
   * @returns Ranked results, best first. Returns `[]` when there is no
   *          authenticated user, so an anonymous caller can never search
   *          across other users' notes.
   */
  async search(
    query: string,
    options: SearchOptions = {}
  ): Promise<SearchResult[]> {
    if (!this.isSearchableQuery(query)) return []

    const session = await auth()
    const userId = session?.user?.id
    // Fail closed: without a user id the queries could not be scoped to an owner.
    if (!userId) return []

    return this._doSearch(query, userId, this.resolveOptions(options))
  }

  /**
   * Hybrid search as an explicit user, without calling `auth()`.
   * Used by agent tools that run server-side without an HTTP session.
   *
   * @param userId  Owner whose notes are searched; an empty id yields `[]`.
   * @param query   Free-text query; queries shorter than 2 characters (after trimming) yield `[]`.
   * @param options Optional limit / threshold / notebook filter / default title.
   */
  async searchAsUser(
    userId: string,
    query: string,
    options: SearchOptions = {}
  ): Promise<SearchResult[]> {
    if (!this.isSearchableQuery(query)) return []
    // Fail closed for the same reason as search(): never run an unscoped query.
    if (!userId) return []

    return this._doSearch(query, userId, this.resolveOptions(options))
  }

  /** True when the query has at least two non-whitespace-trimmed characters. */
  private isSearchableQuery(query: string): boolean {
    return Boolean(query) && query.trim().length >= 2
  }

  /** Applies the service defaults to the options the caller left undefined. */
  private resolveOptions(options: SearchOptions): ResolvedSearchOptions {
    const {
      limit = this.DEFAULT_LIMIT,
      threshold = this.DEFAULT_THRESHOLD,
      notebookId,
      defaultTitle = 'Untitled'
    } = options
    return { limit, threshold, notebookId, defaultTitle }
  }

  /**
   * Runs both searches in parallel, fuses them and trims to `opts.limit`.
   * If anything in the hybrid path throws, falls back to full-text search only.
   */
  private async _doSearch(
    query: string,
    userId: string,
    opts: ResolvedSearchOptions
  ): Promise<SearchResult[]> {
    try {
      const [keywordResults, semanticResults] = await Promise.all([
        this.ftsSearch(query, userId, opts.notebookId),
        this.vectorSearch(query, userId, opts.threshold, opts.notebookId)
      ])
      const fusedResults = await this.reciprocalRankFusion(keywordResults, semanticResults)
      return fusedResults
        .sort((a, b) => b.score - a.score)
        .slice(0, opts.limit)
        .map(result => ({
          ...result,
          title: result.title || opts.defaultTitle
        }))
    } catch (error) {
      console.error('Error in hybrid search:', error)
      return this._ftsFallback(query, userId, opts)
    }
  }

  /**
   * Builds the owner / notebook `AND …` clauses shared by both raw queries.
   * Values are appended to `params` and referenced by position, never inlined.
   *
   * @param alias      Column prefix, e.g. `''` or `'n.'`.
   * @param userId     Owner id; always applied.
   * @param notebookId `undefined` = no notebook filter, any other falsy value =
   *                   notes without a notebook, otherwise an equality filter.
   * @param params     Bind-parameter list, mutated in place.
   */
  private scopeClauses(
    alias: string,
    userId: string,
    notebookId: string | undefined,
    params: unknown[]
  ): string {
    params.push(userId)
    const clauses = [`AND ${alias}"userId" = $${params.length}`]
    if (notebookId !== undefined) {
      if (notebookId) {
        params.push(notebookId)
        clauses.push(`AND ${alias}"notebookId" = $${params.length}`)
      } else {
        clauses.push(`AND ${alias}"notebookId" IS NULL`)
      }
    }
    return clauses.join('\n')
  }

  /**
   * PostgreSQL full-text search over the `tsv` column, ordered by `ts_rank`.
   *
   * @returns Matching note ids with their 1-based position in the ranking.
   */
  private async ftsSearch(
    query: string,
    userId: string,
    notebookId?: string
  ): Promise<Array<{ noteId: string; rank: number }>> {
    // $1 is the raw query text; plainto_tsquery treats it purely as data.
    const params: unknown[] = [query]
    const scope = this.scopeClauses('', userId, notebookId, params)
    const sql = `
      SELECT id AS "noteId", ts_rank("tsv", plainto_tsquery('simple', $1)) AS rank
      FROM "Note"
      WHERE "tsv" @@ plainto_tsquery('simple', $1)
        AND "trashedAt" IS NULL
        AND "isArchived" = false
        ${scope}
      ORDER BY rank DESC
      LIMIT ${this.FTS_CANDIDATES}
    `
    const rows: Array<{ noteId: string; rank: number }> = await prisma.$queryRawUnsafe(sql, ...params)
    // RRF only needs the position in the list, not the raw ts_rank value.
    return rows.map((row, index) => ({
      noteId: row.noteId,
      rank: index + 1
    }))
  }

  /**
   * pgvector cosine-distance search against the stored note embeddings.
   * Keeps only notes whose similarity (`1 - distance`) reaches `threshold`.
   *
   * @returns Matching note ids with their 1-based position, nearest first;
   *          `[]` when the query embedding cannot be generated.
   */
  private async vectorSearch(
    query: string,
    userId: string,
    threshold: number,
    notebookId?: string
  ): Promise<Array<{ noteId: string; rank: number }>> {
    let queryEmbedding: number[]
    try {
      const result = await embeddingService.generateEmbedding(query)
      queryEmbedding = result.embedding
    } catch (error) {
      console.error('Failed to generate query embedding:', error)
      return []
    }

    const vecStr = embeddingService.toVectorString(queryEmbedding)
    // A NaN/Infinity threshold cannot be bound meaningfully; use the default instead.
    const minSimilarity = Number.isFinite(threshold) ? threshold : this.DEFAULT_THRESHOLD
    const params: unknown[] = [vecStr, minSimilarity]
    const scope = this.scopeClauses('n.', userId, notebookId, params)
    const sql = `
      SELECT n.id AS "noteId",
             1 - (e."embedding" <=> $1::vector) AS similarity
      FROM "Note" n
      INNER JOIN "NoteEmbedding" e ON e."noteId" = n.id
      WHERE n."trashedAt" IS NULL
        AND n."isArchived" = false
        ${scope}
        AND 1 - (e."embedding" <=> $1::vector) >= $2::float8
      ORDER BY e."embedding" <=> $1::vector ASC
      LIMIT ${this.VECTOR_CANDIDATES}
    `
    const rows: Array<{ noteId: string; similarity: number }> = await prisma.$queryRawUnsafe(sql, ...params)
    return rows.map((row, index) => ({
      noteId: row.noteId,
      rank: index + 1
    }))
  }

  /**
   * Reciprocal Rank Fusion: every list a note appears in adds
   * `1 / (RRF_K + rank)` to its score, then the notes are loaded.
   *
   * `matchType` is `'exact'` when the note was a full-text hit and
   * `'related'` when only the vector search found it. (RRF scores are at
   * most `2 / (RRF_K + 1)`, so a fixed score cut-off such as 0.8 can never
   * distinguish the two.)
   *
   * @returns Unsorted results; the caller orders them by `score`.
   */
  private async reciprocalRankFusion(
    keywordResults: Array<{ noteId: string; rank: number }>,
    semanticResults: Array<{ noteId: string; rank: number }>
  ): Promise<SearchResult[]> {
    const scores = new Map<string, number>()
    for (const ranked of [keywordResults, semanticResults]) {
      for (const { noteId, rank } of ranked) {
        scores.set(noteId, (scores.get(noteId) ?? 0) + 1 / (this.RRF_K + rank))
      }
    }
    if (scores.size === 0) return []

    const keywordIds = new Set(keywordResults.map(result => result.noteId))
    const notes = await prisma.note.findMany({
      where: { id: { in: Array.from(scores.keys()) }, trashedAt: null },
      select: {
        id: true,
        title: true,
        content: true,
        language: true
      }
    })
    return notes.map(note => ({
      noteId: note.id,
      title: note.title,
      content: note.content,
      score: scores.get(note.id) ?? 0,
      matchType: keywordIds.has(note.id) ? 'exact' as const : 'related' as const,
      language: note.language
    }))
  }

  /**
   * Full-text-only search used when the hybrid path throws.
   * Results keep the full-text ranking order and all carry score `1.0`.
   *
   * @returns Up to `opts.limit` results, or `[]` if this search fails too.
   */
  private async _ftsFallback(
    query: string,
    userId: string,
    opts: ResolvedSearchOptions
  ): Promise<SearchResult[]> {
    try {
      const keywordResults = await this.ftsSearch(query, userId, opts.notebookId)
      const noteIds = keywordResults.slice(0, opts.limit).map(r => r.noteId)
      if (noteIds.length === 0) return []

      const notes = await prisma.note.findMany({
        where: { id: { in: noteIds }, trashedAt: null },
        select: { id: true, title: true, content: true, language: true }
      })
      // `IN (...)` does not preserve order, so restore the full-text ranking.
      const position = new Map<string, number>()
      noteIds.forEach((id, index) => position.set(id, index))
      const results: SearchResult[] = notes.map(note => ({
        noteId: note.id,
        title: note.title || opts.defaultTitle,
        content: note.content,
        score: 1.0,
        matchType: 'exact' as const,
        language: note.language
      }))
      return results.sort(
        (a, b) => (position.get(a.noteId) ?? 0) - (position.get(b.noteId) ?? 0)
      )
    } catch (error) {
      console.error('Error in full-text fallback search:', error)
      return []
    }
  }

  /**
   * Generates (or refreshes) the embedding of one note and upserts it into
   * `NoteEmbedding` as a native pgvector value, then stamps `lastAiAnalysis`.
   * Does nothing when `embeddingService.shouldRegenerateEmbedding` says the
   * stored embedding is still valid.
   *
   * @throws Re-throws any failure (including "Note not found") after logging it.
   */
  async indexNote(noteId: string): Promise<void> {
    try {
      const note = await prisma.note.findUnique({
        where: { id: noteId },
        select: { content: true, lastAiAnalysis: true }
      })
      if (!note) throw new Error('Note not found')

      const shouldRegenerate = embeddingService.shouldRegenerateEmbedding(
        note.content,
        null,
        note.lastAiAnalysis
      )
      if (!shouldRegenerate) return

      const { embedding } = await embeddingService.generateEmbedding(note.content)
      const vecStr = embeddingService.toVectorString(embedding)
      await prisma.$executeRawUnsafe(
        `INSERT INTO "NoteEmbedding" ("id", "noteId", "embedding", "createdAt", "updatedAt")
         VALUES (gen_random_uuid(), $1, $2::vector, now(), now())
         ON CONFLICT ("noteId")
         DO UPDATE SET "embedding" = $2::vector, "updatedAt" = now()`,
        noteId,
        vecStr
      )
      await prisma.note.update({
        where: { id: noteId },
        data: { lastAiAnalysis: new Date() }
      })
    } catch (error) {
      console.error(`Error indexing note ${noteId}:`, error)
      throw error
    }
  }

  /**
   * Indexes many notes, 20 at a time. Individual failures are logged by
   * `indexNote` and do not abort the remaining notes.
   */
  async indexBatchNotes(noteIds: string[]): Promise<void> {
    const BATCH_SIZE = 20
    for (let i = 0; i < noteIds.length; i += BATCH_SIZE) {
      const batch = noteIds.slice(i, i + BATCH_SIZE)
      await Promise.allSettled(batch.map(noteId => this.indexNote(noteId)))
    }
  }
}
/** Module-level instance of {@link SemanticSearchService} exported for shared use. */
export const semanticSearchService = new SemanticSearchService()