// rag.service.ts — RAG (retrieval-augmented generation) search service
  1. import { Injectable, Logger, Inject, forwardRef } from '@nestjs/common';
  2. import { ConfigService } from '@nestjs/config';
  3. import { ElasticsearchService } from '../elasticsearch/elasticsearch.service';
  4. import { EmbeddingService } from '../knowledge-base/embedding.service';
  5. import { ModelConfigService } from '../model-config/model-config.service';
  6. import { RerankService } from './rerank.service';
  7. import { I18nService } from '../i18n/i18n.service';
  8. import { TenantService } from '../tenant/tenant.service';
  9. import { ChatOpenAI } from '@langchain/openai';
  10. import { ModelConfig } from '../types';
  11. import { UserSettingService } from '../user/user-setting.service';
  12. export interface RagSearchResult {
  13. content: string;
  14. fileName: string;
  15. score: number;
  16. chunkIndex: number;
  17. fileId?: string;
  18. originalScore?: number; // Original score before reranking (for debugging)
  19. metadata?: any;
  20. }
  21. @Injectable()
  22. export class RagService {
  23. private readonly logger = new Logger(RagService.name);
  24. private readonly defaultDimensions: number;
  25. constructor(
  26. @Inject(forwardRef(() => ElasticsearchService))
  27. private elasticsearchService: ElasticsearchService,
  28. private embeddingService: EmbeddingService,
  29. private modelConfigService: ModelConfigService,
  30. private rerankService: RerankService,
  31. private configService: ConfigService,
  32. private i18nService: I18nService,
  33. private tenantService: TenantService,
  34. private userSettingService: UserSettingService,
  35. ) {
  36. this.defaultDimensions = parseInt(
  37. this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
  38. );
  39. this.logger.log(`RAG service default vector dimensions: ${this.defaultDimensions}`);
  40. }
  41. async searchKnowledge(
  42. query: string,
  43. userId: string,
  44. topK: number = 5,
  45. vectorSimilarityThreshold: number = 0.3, // Vector search similarity threshold
  46. embeddingModelId?: string,
  47. enableFullTextSearch: boolean = false,
  48. enableRerank: boolean = false,
  49. rerankModelId?: string,
  50. selectedGroups?: string[],
  51. effectiveFileIds?: string[],
  52. rerankSimilarityThreshold: number = 0.5, // Rerank similarity threshold (default 0.5)
  53. tenantId?: string, // New
  54. enableQueryExpansion?: boolean,
  55. enableHyDE?: boolean,
  56. ): Promise<RagSearchResult[]> {
  57. // 1. Get organization settings
  58. const globalSettings = await this.tenantService.getSettings(tenantId || 'default');
  59. // Use global settings if parameters are not explicitly provided
  60. const effectiveTopK = topK || globalSettings?.topK || 5;
  61. const effectiveVectorThreshold = vectorSimilarityThreshold !== undefined ? vectorSimilarityThreshold : (globalSettings?.similarityThreshold || 0.3);
  62. const effectiveRerankThreshold = rerankSimilarityThreshold !== undefined ? rerankSimilarityThreshold : (globalSettings?.rerankSimilarityThreshold || 0.5);
  63. const effectiveEnableRerank = enableRerank !== undefined ? enableRerank : (globalSettings?.enableRerank ?? false);
  64. const effectiveEnableFullText = enableFullTextSearch !== undefined ? enableFullTextSearch : (globalSettings?.enableFullTextSearch ?? false);
  65. const effectiveEmbeddingId = embeddingModelId || globalSettings?.selectedEmbeddingId;
  66. const effectiveRerankId = rerankModelId || globalSettings?.selectedRerankId;
  67. const effectiveHybridWeight = globalSettings?.hybridVectorWeight ?? 0.7;
  68. const effectiveEnableQueryExpansion = enableQueryExpansion !== undefined ? enableQueryExpansion : (globalSettings?.enableQueryExpansion ?? false);
  69. const effectiveEnableHyDE = enableHyDE !== undefined ? enableHyDE : (globalSettings?.enableHyDE ?? false);
  70. this.logger.log(
  71. `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}, QueryExpansion=${effectiveEnableQueryExpansion}, HyDE=${effectiveEnableHyDE}`,
  72. );
  73. try {
  74. // 1. Prepare query (expansion or HyDE)
  75. let queriesToSearch = [query];
  76. if (effectiveEnableHyDE) {
  77. const hydeDoc = await this.generateHyDE(query, userId);
  78. queriesToSearch = [hydeDoc]; // Use virtual document as query for HyDE
  79. } else if (effectiveEnableQueryExpansion) {
  80. const expanded = await this.expandQuery(query, userId);
  81. queriesToSearch = [...new Set([query, ...expanded])];
  82. }
  83. // Check if embedding model ID is provided
  84. if (!effectiveEmbeddingId) {
  85. throw new Error(this.i18nService.getMessage('embeddingModelIdNotProvided'));
  86. }
  87. // 2. Parallel search for multiple queries
  88. const searchTasks = queriesToSearch.map(async (searchQuery) => {
  89. // Get query vector
  90. const queryEmbedding = await this.embeddingService.getEmbeddings(
  91. [searchQuery],
  92. userId,
  93. effectiveEmbeddingId,
  94. );
  95. const queryVector = queryEmbedding[0];
  96. // Select search strategy based on settings
  97. let results;
  98. if (effectiveEnableFullText) {
  99. results = await this.elasticsearchService.hybridSearch(
  100. queryVector,
  101. searchQuery,
  102. userId,
  103. effectiveTopK * 4,
  104. effectiveHybridWeight,
  105. undefined,
  106. effectiveFileIds,
  107. tenantId,
  108. );
  109. } else {
  110. let vectorSearchResults = await this.elasticsearchService.searchSimilar(
  111. queryVector,
  112. userId,
  113. effectiveTopK * 4,
  114. tenantId,
  115. );
  116. if (effectiveFileIds && effectiveFileIds.length > 0) {
  117. results = vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
  118. } else {
  119. results = vectorSearchResults;
  120. }
  121. }
  122. return results;
  123. });
  124. const allResultsRaw = await Promise.all(searchTasks);
  125. let searchResults = this.deduplicateResults(allResultsRaw.flat());
  126. // Initial similarity filtering
  127. const initialCount = searchResults.length;
  128. // Log output
  129. searchResults.forEach((r, idx) => {
  130. this.logger.log(`Hit ${idx}: score=${r.score.toFixed(4)}, fileName=${r.fileName}`);
  131. });
  132. // Apply threshold filtering
  133. searchResults = searchResults.filter(r => r.score >= effectiveVectorThreshold);
  134. this.logger.log(`Initial hits: ${initialCount} -> filtered by vectorThreshold: ${searchResults.length}`);
  135. // 3. Rerank
  136. let finalResults = searchResults;
  137. if (effectiveEnableRerank && effectiveRerankId && searchResults.length > 0) {
  138. try {
  139. const docs = searchResults.map(r => r.content);
  140. const rerankedIndices = await this.rerankService.rerank(
  141. query,
  142. docs,
  143. userId,
  144. effectiveRerankId,
  145. effectiveTopK * 2 // Keep a bit more results
  146. );
  147. finalResults = rerankedIndices.map(r => {
  148. const originalItem = searchResults[r.index];
  149. return {
  150. ...originalItem,
  151. score: r.score, // Rerank score
  152. originalScore: originalItem.score // Original score
  153. };
  154. });
  155. // Filter after reranking
  156. const beforeRerankFilter = finalResults.length;
  157. finalResults = finalResults.filter(r => r.score >= effectiveRerankThreshold);
  158. this.logger.log(`After rerank: ${beforeRerankFilter} -> filtered by rerankThreshold: ${finalResults.length}`);
  159. } catch (error) {
  160. this.logger.warn(`Rerank failed, falling back to filtered vector search: ${error.message}`);
  161. // Fall back to filtered vector search results if rerank fails
  162. }
  163. }
  164. // Final result count limit
  165. finalResults = finalResults.slice(0, effectiveTopK);
  166. // 4. Convert to RAG result format
  167. const ragResults: RagSearchResult[] = finalResults.map((result) => ({
  168. content: result.content,
  169. fileName: result.fileName,
  170. score: result.score,
  171. originalScore: result.originalScore !== undefined ? result.originalScore : result.score,
  172. chunkIndex: result.chunkIndex,
  173. fileId: result.fileId,
  174. metadata: result.metadata,
  175. }));
  176. return ragResults;
  177. } catch (error) {
  178. this.logger.error('RAG search failed:', error);
  179. return [];
  180. }
  181. }
  182. buildRagPrompt(
  183. query: string,
  184. searchResults: RagSearchResult[],
  185. language: string = 'ja',
  186. ): string {
  187. const lang = language || 'ja';
  188. // Build context
  189. let context = '';
  190. if (searchResults.length === 0) {
  191. context = this.i18nService.getMessage('ragNoDocumentFound', lang);
  192. } else {
  193. // Group by file
  194. const fileGroups = new Map<string, RagSearchResult[]>();
  195. searchResults.forEach((result) => {
  196. if (!fileGroups.has(result.fileName)) {
  197. fileGroups.set(result.fileName, []);
  198. }
  199. fileGroups.get(result.fileName)!.push(result);
  200. });
  201. // Build context string
  202. const contextParts: string[] = [];
  203. fileGroups.forEach((chunks, fileName) => {
  204. contextParts.push(this.i18nService.formatMessage('ragSource', { fileName }, lang));
  205. chunks.forEach((chunk, index) => {
  206. contextParts.push(
  207. this.i18nService.formatMessage('ragSegment', {
  208. index: index + 1,
  209. score: chunk.score.toFixed(3)
  210. }, lang),
  211. );
  212. contextParts.push(chunk.content);
  213. contextParts.push('');
  214. });
  215. });
  216. context = contextParts.join('\n');
  217. }
  218. const langText = lang === 'zh' ? 'Chinese' : lang === 'en' ? 'English' : 'Japanese';
  219. const systemPrompt = this.i18nService.getMessage('ragSystemPrompt', lang);
  220. const rules = this.i18nService.formatMessage('ragRules', { lang: langText }, lang);
  221. const docContentHeader = this.i18nService.getMessage('ragDocumentContent', lang);
  222. const userQuestionHeader = this.i18nService.getMessage('ragUserQuestion', lang);
  223. const answerHeader = this.i18nService.getMessage('ragAnswer', lang);
  224. return `${systemPrompt}
  225. ${rules}
  226. ${docContentHeader}
  227. ${context}
  228. ${userQuestionHeader}
  229. ${query}
  230. ${answerHeader}`;
  231. }
  232. extractSources(searchResults: RagSearchResult[]): string[] {
  233. const uniqueFiles = new Set<string>();
  234. searchResults.forEach((result) => {
  235. uniqueFiles.add(result.fileName);
  236. });
  237. return Array.from(uniqueFiles);
  238. }
  239. /**
  240. * Deduplicate search results
  241. */
  242. private deduplicateResults(results: any[]): any[] {
  243. const unique = new Map<string, any>();
  244. results.forEach(r => {
  245. const key = `${r.fileId}_${r.chunkIndex}`;
  246. if (!unique.has(key) || unique.get(key)!.score < r.score) {
  247. unique.set(key, r);
  248. }
  249. });
  250. return Array.from(unique.values()).sort((a, b) => b.score - a.score);
  251. }
  252. /**
  253. * Expand query to generate variations
  254. */
  255. async expandQuery(query: string, userId: string, tenantId?: string): Promise<string[]> {
  256. try {
  257. const llm = await this.getInternalLlm(userId, tenantId || 'default');
  258. if (!llm) return [query];
  259. const userSettings = await this.userSettingService.getByUser(userId);
  260. const lang = userSettings.language || 'zh';
  261. const prompt = this.i18nService.formatMessage('queryExpansionPrompt', { query }, lang);
  262. const response = await llm.invoke(prompt);
  263. const content = String(response.content);
  264. const expandedQueries = content
  265. .split('\n')
  266. .map(q => q.trim())
  267. .filter(q => q.length > 0)
  268. .slice(0, 3); // Limit to maximum 3
  269. this.logger.log(`Query expanded: "${query}" -> [${expandedQueries.join(', ')}]`);
  270. return expandedQueries.length > 0 ? expandedQueries : [query];
  271. } catch (error) {
  272. this.logger.error('Query expansion failed:', error);
  273. return [query];
  274. }
  275. }
  276. /**
  277. * Generate hypothetical document (HyDE)
  278. */
  279. async generateHyDE(query: string, userId: string, tenantId?: string): Promise<string> {
  280. try {
  281. const llm = await this.getInternalLlm(userId, tenantId || 'default');
  282. if (!llm) return query;
  283. const userSettings = await this.userSettingService.getByUser(userId);
  284. const lang = userSettings.language || 'zh';
  285. const prompt = this.i18nService.formatMessage('hydePrompt', { query }, lang);
  286. const response = await llm.invoke(prompt);
  287. const hydeDoc = String(response.content).trim();
  288. this.logger.log(`HyDE generated for: "${query}" (length: ${hydeDoc.length})`);
  289. return hydeDoc || query;
  290. } catch (error) {
  291. this.logger.error('HyDE generation failed:', error);
  292. return query;
  293. }
  294. }
  295. /**
  296. * Get LLM instance for internal tasks
  297. */
  298. private async getInternalLlm(userId: string, tenantId: string): Promise<ChatOpenAI | null> {
  299. try {
  300. const models = await this.modelConfigService.findAll(userId, tenantId || 'default');
  301. const defaultLlm = models.find(m => m.type === 'llm' && m.isDefault && m.isEnabled !== false);
  302. if (!defaultLlm) {
  303. this.logger.warn('No enabled LLM configured for internal tasks');
  304. return null;
  305. }
  306. return new ChatOpenAI({
  307. apiKey: defaultLlm.apiKey || 'ollama',
  308. temperature: 0.3,
  309. modelName: defaultLlm.modelId,
  310. configuration: {
  311. baseURL: defaultLlm.baseUrl || 'http://localhost:11434/v1',
  312. },
  313. });
  314. } catch (error) {
  315. this.logger.error('Failed to get internal LLM:', error);
  316. return null;
  317. }
  318. }
  319. }