import { Injectable, Logger, Inject, forwardRef } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { ElasticsearchService } from '../elasticsearch/elasticsearch.service';
import { EmbeddingService } from '../knowledge-base/embedding.service';
import { ModelConfigService } from '../model-config/model-config.service';
import { RerankService } from './rerank.service';
import { I18nService } from '../i18n/i18n.service';
import { TenantService } from '../tenant/tenant.service';
import { ChatOpenAI } from '@langchain/openai';
import { ModelConfig } from '../types';
import { UserSettingService } from '../user/user-setting.service';

/**
 * A single chunk returned by {@link RagService.searchKnowledge}.
 */
export interface RagSearchResult {
  content: string;
  fileName: string;
  /** Rerank score when reranking was applied, otherwise the search score. */
  score: number;
  chunkIndex: number;
  fileId?: string;
  originalScore?: number; // Original score before reranking (for debugging)
  metadata?: any;
}

/**
 * Retrieval service: embeds a query, searches Elasticsearch (vector or hybrid),
 * optionally reranks the hits, and formats the hits into an LLM prompt.
 */
@Injectable()
export class RagService {
  private readonly logger = new Logger(RagService.name);
  private readonly defaultDimensions: number;

  constructor(
    @Inject(forwardRef(() => ElasticsearchService))
    private elasticsearchService: ElasticsearchService,
    private embeddingService: EmbeddingService,
    private modelConfigService: ModelConfigService,
    private rerankService: RerankService,
    private configService: ConfigService,
    private i18nService: I18nService,
    private tenantService: TenantService,
    private userSettingService: UserSettingService,
  ) {
    // Explicit radix so the configured value is always parsed as decimal.
    this.defaultDimensions = parseInt(
      this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
      10,
    );
    this.logger.log(`RAG service default vector dimensions: ${this.defaultDimensions}`);
  }

  /**
   * Searches the knowledge base for chunks relevant to `query`.
   *
   * Steps: resolve effective settings, optionally rewrite the query (HyDE takes
   * precedence over query expansion), embed and search each query in parallel,
   * deduplicate, filter by the vector threshold, optionally rerank and filter by
   * the rerank threshold, then keep at most `topK` results.
   *
   * Any failure inside the search pipeline is logged and yields an empty array;
   * only a failure while loading tenant settings propagates to the caller.
   *
   * @param query User query text.
   * @param userId User on whose behalf embeddings, search and rerank are run.
   * @param topK Maximum number of results returned.
   * @param vectorSimilarityThreshold Minimum search score a hit must reach.
   * @param embeddingModelId Embedding model; falls back to the tenant's `selectedEmbeddingId`.
   * @param enableFullTextSearch Use hybrid search instead of pure vector search.
   * @param enableRerank Rerank hits when a rerank model ID is available.
   * @param rerankModelId Rerank model; falls back to the tenant's `selectedRerankId`.
   * @param selectedGroups Not read by this method.
   * @param effectiveFileIds When non-empty, restricts results to these file IDs.
   * @param rerankSimilarityThreshold Minimum rerank score a hit must reach.
   * @param tenantId Tenant whose settings are loaded (`'default'` when omitted).
   * @param enableQueryExpansion Overrides the tenant's `enableQueryExpansion`.
   * @param enableHyDE Overrides the tenant's `enableHyDE`.
   * @returns Matching chunks, or `[]` when the search fails.
   */
  async searchKnowledge(
    query: string,
    userId: string,
    topK: number = 5,
    vectorSimilarityThreshold: number = 0.3, // Vector search similarity threshold
    embeddingModelId?: string,
    enableFullTextSearch: boolean = false,
    enableRerank: boolean = false,
    rerankModelId?: string,
    selectedGroups?: string[],
    effectiveFileIds?: string[],
    rerankSimilarityThreshold: number = 0.5, // Rerank similarity threshold (default 0.5)
    tenantId?: string, // New
    enableQueryExpansion?: boolean,
    enableHyDE?: boolean,
  ): Promise<RagSearchResult[]> {
    // 1. Get organization settings
    const globalSettings = await this.tenantService.getSettings(tenantId || 'default');

    // Use global settings if parameters are not explicitly provided.
    // NOTE(review): topK, both thresholds, enableFullTextSearch and enableRerank
    // have parameter defaults, so an omitted argument is never `undefined` here
    // and the tenant-level fallbacks for them are effectively unreachable (topK
    // only falls through when it is falsy, e.g. 0). Making them reachable means
    // changing the public defaults - confirm with callers before doing so.
    const effectiveTopK = topK || globalSettings?.topK || 5;
    const effectiveVectorThreshold =
      vectorSimilarityThreshold !== undefined
        ? vectorSimilarityThreshold
        : (globalSettings?.similarityThreshold || 0.3);
    const effectiveRerankThreshold =
      rerankSimilarityThreshold !== undefined
        ? rerankSimilarityThreshold
        : (globalSettings?.rerankSimilarityThreshold || 0.5);
    const effectiveEnableRerank =
      enableRerank !== undefined ? enableRerank : (globalSettings?.enableRerank ?? false);
    const effectiveEnableFullText =
      enableFullTextSearch !== undefined
        ? enableFullTextSearch
        : (globalSettings?.enableFullTextSearch ?? false);
    const effectiveEmbeddingId = embeddingModelId || globalSettings?.selectedEmbeddingId;
    const effectiveRerankId = rerankModelId || globalSettings?.selectedRerankId;
    const effectiveHybridWeight = globalSettings?.hybridVectorWeight ?? 0.7;
    const effectiveEnableQueryExpansion =
      enableQueryExpansion !== undefined
        ? enableQueryExpansion
        : (globalSettings?.enableQueryExpansion ?? false);
    const effectiveEnableHyDE =
      enableHyDE !== undefined ? enableHyDE : (globalSettings?.enableHyDE ?? false);

    this.logger.log(
      `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}, QueryExpansion=${effectiveEnableQueryExpansion}, HyDE=${effectiveEnableHyDE}`,
    );

    try {
      // Fail fast: without an embedding model nothing below can succeed, so do
      // not spend an LLM call on HyDE / query expansion first.
      if (!effectiveEmbeddingId) {
        throw new Error(this.i18nService.getMessage('embeddingModelIdNotProvided'));
      }

      // 1. Prepare query (expansion or HyDE). The tenant ID is forwarded so the
      // helper LLM is resolved for the same tenant as the search itself.
      let queriesToSearch = [query];
      if (effectiveEnableHyDE) {
        const hydeDoc = await this.generateHyDE(query, userId, tenantId);
        queriesToSearch = [hydeDoc]; // Use virtual document as query for HyDE
      } else if (effectiveEnableQueryExpansion) {
        const expanded = await this.expandQuery(query, userId, tenantId);
        queriesToSearch = [...new Set([query, ...expanded])];
      }

      // Over-fetch candidates; threshold filtering, reranking and the final
      // slice narrow them down to topK afterwards.
      const candidateCount = effectiveTopK * 4;

      // 2. Parallel search for multiple queries
      const searchTasks = queriesToSearch.map(async (searchQuery) => {
        // Get query vector
        const queryEmbedding = await this.embeddingService.getEmbeddings(
          [searchQuery],
          userId,
          effectiveEmbeddingId,
        );
        const queryVector = queryEmbedding[0];

        // Select search strategy based on settings
        if (effectiveEnableFullText) {
          return this.elasticsearchService.hybridSearch(
            queryVector,
            searchQuery,
            userId,
            candidateCount,
            effectiveHybridWeight,
            undefined,
            effectiveFileIds,
            tenantId,
          );
        }

        const vectorSearchResults = await this.elasticsearchService.searchSimilar(
          queryVector,
          userId,
          candidateCount,
          tenantId,
        );

        // Pure vector search is not given the file IDs, so filter client-side.
        if (effectiveFileIds && effectiveFileIds.length > 0) {
          return vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
        }
        return vectorSearchResults;
      });

      const allResultsRaw = await Promise.all(searchTasks);
      let searchResults = this.deduplicateResults(allResultsRaw.flat());

      // Initial similarity filtering
      const initialCount = searchResults.length;

      // Log output
      searchResults.forEach((r, idx) => {
        this.logger.log(`Hit ${idx}: score=${r.score.toFixed(4)}, fileName=${r.fileName}`);
      });

      // Apply threshold filtering
      searchResults = searchResults.filter(r => r.score >= effectiveVectorThreshold);
      this.logger.log(`Initial hits: ${initialCount} -> filtered by vectorThreshold: ${searchResults.length}`);

      // 3. Rerank
      let finalResults = searchResults;
      if (effectiveEnableRerank && effectiveRerankId && searchResults.length > 0) {
        try {
          const docs = searchResults.map(r => r.content);
          const rerankedIndices = await this.rerankService.rerank(
            query,
            docs,
            userId,
            effectiveRerankId,
            effectiveTopK * 2, // Keep a bit more results
          );

          const reranked = rerankedIndices.map(r => {
            const originalItem = searchResults[r.index];
            return {
              ...originalItem,
              score: r.score, // Rerank score
              originalScore: originalItem.score, // Original score
            };
          });

          // Filter after reranking
          const beforeRerankFilter = reranked.length;
          finalResults = reranked.filter(r => r.score >= effectiveRerankThreshold);
          this.logger.log(`After rerank: ${beforeRerankFilter} -> filtered by rerankThreshold: ${finalResults.length}`);
        } catch (error: unknown) {
          // Fall back to filtered vector search results if rerank fails
          const reason = error instanceof Error ? error.message : String(error);
          this.logger.warn(`Rerank failed, falling back to filtered vector search: ${reason}`);
        }
      }

      // Final result count limit
      finalResults = finalResults.slice(0, effectiveTopK);

      // 4. Convert to RAG result format
      const ragResults: RagSearchResult[] = finalResults.map((result) => ({
        content: result.content,
        fileName: result.fileName,
        score: result.score,
        originalScore: result.originalScore !== undefined ? result.originalScore : result.score,
        chunkIndex: result.chunkIndex,
        fileId: result.fileId,
        metadata: result.metadata,
      }));

      return ragResults;
    } catch (error) {
      this.logger.error('RAG search failed:', error);
      return [];
    }
  }

  /**
   * Builds the prompt sent to the LLM from the query and the retrieved chunks.
   *
   * Chunks are grouped by file name; each group is introduced by the localized
   * `ragSource` message and each chunk by the localized `ragSegment` message
   * (1-based index within its file, score with three decimals). When there are
   * no results the localized `ragNoDocumentFound` message is used as context.
   *
   * @param query User question inserted into the prompt.
   * @param searchResults Chunks to use as context.
   * @param language Message language; empty values fall back to `'ja'`.
   * @returns The assembled prompt string.
   */
  buildRagPrompt(
    query: string,
    searchResults: RagSearchResult[],
    language: string = 'ja',
  ): string {
    const lang = language || 'ja';

    // Build context
    let context = '';
    if (searchResults.length === 0) {
      context = this.i18nService.getMessage('ragNoDocumentFound', lang);
    } else {
      // Group by file (Map keeps first-seen file order)
      const fileGroups = new Map<string, RagSearchResult[]>();
      for (const result of searchResults) {
        const group = fileGroups.get(result.fileName);
        if (group) {
          group.push(result);
        } else {
          fileGroups.set(result.fileName, [result]);
        }
      }

      // Build context string
      const contextParts: string[] = [];
      fileGroups.forEach((chunks, fileName) => {
        contextParts.push(this.i18nService.formatMessage('ragSource', { fileName }, lang));
        chunks.forEach((chunk, index) => {
          contextParts.push(
            this.i18nService.formatMessage(
              'ragSegment',
              { index: index + 1, score: chunk.score.toFixed(3) },
              lang,
            ),
          );
          contextParts.push(chunk.content);
          contextParts.push('');
        });
      });
      context = contextParts.join('\n');
    }

    // Language name handed to the rules template; anything other than zh/en
    // is treated as Japanese.
    const langText = lang === 'zh' ? 'Chinese' : lang === 'en' ? 'English' : 'Japanese';
    const systemPrompt = this.i18nService.getMessage('ragSystemPrompt', lang);
    const rules = this.i18nService.formatMessage('ragRules', { lang: langText }, lang);
    const docContentHeader = this.i18nService.getMessage('ragDocumentContent', lang);
    const userQuestionHeader = this.i18nService.getMessage('ragUserQuestion', lang);
    const answerHeader = this.i18nService.getMessage('ragAnswer', lang);

    // NOTE(review): the sections are joined with single spaces here; confirm
    // whether line breaks between sections were intended.
    return `${systemPrompt} ${rules} ${docContentHeader} ${context} ${userQuestionHeader} ${query} ${answerHeader}`;
  }

  /**
   * Returns the distinct file names present in `searchResults`, in first-seen order.
   */
  extractSources(searchResults: RagSearchResult[]): string[] {
    const uniqueFiles = new Set<string>();
    for (const result of searchResults) {
      uniqueFiles.add(result.fileName);
    }
    return Array.from(uniqueFiles);
  }

  /**
   * Deduplicate search results.
   *
   * Hits are keyed by `fileId` + `chunkIndex`; for duplicates the hit with the
   * higher score is kept. The result is sorted by score, highest first.
   */
  private deduplicateResults(results: any[]): any[] {
    const unique = new Map<string, any>();
    for (const r of results) {
      const key = `${r.fileId}_${r.chunkIndex}`;
      const existing = unique.get(key);
      if (existing === undefined || existing.score < r.score) {
        unique.set(key, r);
      }
    }
    return Array.from(unique.values()).sort((a, b) => b.score - a.score);
  }

  /**
   * Expand query to generate variations.
   *
   * Sends the localized `queryExpansionPrompt` to the internal LLM and treats
   * each non-empty line of the reply as one variation (at most three).
   *
   * @param query Original query.
   * @param userId User whose language setting and model configuration are used.
   * @param tenantId Tenant used to look up the LLM (`'default'` when omitted).
   * @returns The variations, or `[query]` when no LLM is available, the reply
   *   is empty, or anything fails.
   */
  async expandQuery(query: string, userId: string, tenantId?: string): Promise<string[]> {
    try {
      const llm = await this.getInternalLlm(userId, tenantId || 'default');
      if (!llm) return [query];

      const userSettings = await this.userSettingService.getByUser(userId);
      // Tolerate missing settings instead of throwing into the catch below.
      const lang = userSettings?.language || 'zh';

      const prompt = this.i18nService.formatMessage('queryExpansionPrompt', { query }, lang);
      const response = await llm.invoke(prompt);
      const content = String(response.content);

      const expandedQueries = content
        .split('\n')
        .map(q => q.trim())
        .filter(q => q.length > 0)
        .slice(0, 3); // Limit to maximum 3

      this.logger.log(`Query expanded: "${query}" -> [${expandedQueries.join(', ')}]`);
      return expandedQueries.length > 0 ? expandedQueries : [query];
    } catch (error) {
      this.logger.error('Query expansion failed:', error);
      return [query];
    }
  }

  /**
   * Generate hypothetical document (HyDE).
   *
   * Sends the localized `hydePrompt` to the internal LLM and returns the
   * trimmed reply.
   *
   * @param query Original query.
   * @param userId User whose language setting and model configuration are used.
   * @param tenantId Tenant used to look up the LLM (`'default'` when omitted).
   * @returns The generated text, or `query` itself when no LLM is available,
   *   the reply is empty, or anything fails.
   */
  async generateHyDE(query: string, userId: string, tenantId?: string): Promise<string> {
    try {
      const llm = await this.getInternalLlm(userId, tenantId || 'default');
      if (!llm) return query;

      const userSettings = await this.userSettingService.getByUser(userId);
      // Tolerate missing settings instead of throwing into the catch below.
      const lang = userSettings?.language || 'zh';

      const prompt = this.i18nService.formatMessage('hydePrompt', { query }, lang);
      const response = await llm.invoke(prompt);
      const hydeDoc = String(response.content).trim();

      this.logger.log(`HyDE generated for: "${query}" (length: ${hydeDoc.length})`);
      return hydeDoc || query;
    } catch (error) {
      this.logger.error('HyDE generation failed:', error);
      return query;
    }
  }

  /**
   * Get LLM instance for internal tasks.
   *
   * Picks the model configuration of type `'llm'` that is marked default and
   * not explicitly disabled, and wraps it in a `ChatOpenAI` client
   * (temperature 0.3). A missing API key falls back to `'ollama'` and a missing
   * base URL to `http://localhost:11434/v1`.
   *
   * @returns The client, or `null` when no such model exists or the lookup fails.
   */
  private async getInternalLlm(userId: string, tenantId: string): Promise<ChatOpenAI | null> {
    try {
      const models = await this.modelConfigService.findAll(userId, tenantId || 'default');
      const defaultLlm = models.find(
        m => m.type === 'llm' && m.isDefault && m.isEnabled !== false,
      );

      if (!defaultLlm) {
        this.logger.warn('No enabled LLM configured for internal tasks');
        return null;
      }

      return new ChatOpenAI({
        apiKey: defaultLlm.apiKey || 'ollama',
        temperature: 0.3,
        modelName: defaultLlm.modelId,
        configuration: {
          baseURL: defaultLlm.baseUrl || 'http://localhost:11434/v1',
        },
      });
    } catch (error) {
      this.logger.error('Failed to get internal LLM:', error);
      return null;
    }
  }
}