rag.service.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. import { Injectable, Logger, Inject, forwardRef } from '@nestjs/common';
  2. import { ConfigService } from '@nestjs/config';
  3. import { ElasticsearchService } from '../elasticsearch/elasticsearch.service';
  4. import { EmbeddingService } from '../knowledge-base/embedding.service';
  5. import { ModelConfigService } from '../model-config/model-config.service';
  6. import { RerankService } from './rerank.service';
  7. import { I18nService } from '../i18n/i18n.service';
  8. import { UserSettingService } from '../user-setting/user-setting.service';
  9. import { ChatOpenAI } from '@langchain/openai';
  10. import { ModelConfig } from '../types';
  11. export interface RagSearchResult {
  12. content: string;
  13. fileName: string;
  14. score: number;
  15. chunkIndex: number;
  16. fileId?: string;
  17. originalScore?: number; // Rerank前のスコア(デバッグ用)
  18. metadata?: any;
  19. }
  20. @Injectable()
  21. export class RagService {
  22. private readonly logger = new Logger(RagService.name);
  23. private readonly defaultDimensions: number;
  24. constructor(
  25. @Inject(forwardRef(() => ElasticsearchService))
  26. private elasticsearchService: ElasticsearchService,
  27. private embeddingService: EmbeddingService,
  28. private modelConfigService: ModelConfigService,
  29. private rerankService: RerankService,
  30. private configService: ConfigService,
  31. private i18nService: I18nService,
  32. private userSettingService: UserSettingService,
  33. ) {
  34. this.defaultDimensions = parseInt(
  35. this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
  36. );
  37. this.logger.log(`RAG サービスのデフォルトベクトル次元数: ${this.defaultDimensions}`);
  38. }
  39. async searchKnowledge(
  40. query: string,
  41. userId: string,
  42. topK: number = 5,
  43. vectorSimilarityThreshold: number = 0.3, // ベクトル検索のしきい値
  44. embeddingModelId?: string,
  45. enableFullTextSearch: boolean = false,
  46. enableRerank: boolean = false,
  47. rerankModelId?: string,
  48. selectedGroups?: string[],
  49. effectiveFileIds?: string[],
  50. rerankSimilarityThreshold: number = 0.5, // Rerankのしきい値(デフォルト0.5)
  51. tenantId?: string, // New
  52. enableQueryExpansion?: boolean,
  53. enableHyDE?: boolean,
  54. ): Promise<RagSearchResult[]> {
  55. // 1. グローバル設定の取得
  56. const globalSettings = await this.userSettingService.getGlobalSettings();
  57. // パラメータが明示的に渡されていない場合はグローバル設定を使用
  58. const effectiveTopK = topK || globalSettings.topK || 5;
  59. const effectiveVectorThreshold = vectorSimilarityThreshold !== undefined ? vectorSimilarityThreshold : (globalSettings.similarityThreshold || 0.3);
  60. const effectiveRerankThreshold = rerankSimilarityThreshold !== undefined ? rerankSimilarityThreshold : (globalSettings.rerankSimilarityThreshold || 0.5);
  61. const effectiveEnableRerank = enableRerank !== undefined ? enableRerank : globalSettings.enableRerank;
  62. const effectiveEnableFullText = enableFullTextSearch !== undefined ? enableFullTextSearch : globalSettings.enableFullTextSearch;
  63. const effectiveEmbeddingId = embeddingModelId || globalSettings.selectedEmbeddingId;
  64. const effectiveRerankId = rerankModelId || globalSettings.selectedRerankId;
  65. const effectiveHybridWeight = globalSettings.hybridVectorWeight ?? 0.7;
  66. const effectiveEnableQueryExpansion = enableQueryExpansion !== undefined ? enableQueryExpansion : globalSettings.enableQueryExpansion;
  67. const effectiveEnableHyDE = enableHyDE !== undefined ? enableHyDE : globalSettings.enableHyDE;
  68. this.logger.log(
  69. `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}, QueryExpansion=${effectiveEnableQueryExpansion}, HyDE=${effectiveEnableHyDE}`,
  70. );
  71. try {
  72. // 1. クエリの準備(拡張または HyDE)
  73. let queriesToSearch = [query];
  74. if (effectiveEnableHyDE) {
  75. const hydeDoc = await this.generateHyDE(query, userId);
  76. queriesToSearch = [hydeDoc]; // HyDE の場合は仮想ドキュメントをクエリとして使用
  77. } else if (effectiveEnableQueryExpansion) {
  78. const expanded = await this.expandQuery(query, userId);
  79. queriesToSearch = [...new Set([query, ...expanded])];
  80. }
  81. // 埋め込みモデルIDが提供されているか確認
  82. if (!effectiveEmbeddingId) {
  83. throw new Error('埋め込みモデルIDが提供されていません');
  84. }
  85. // 2. 複数のクエリに対して並列検索
  86. const searchTasks = queriesToSearch.map(async (searchQuery) => {
  87. // クエリベクトルの取得
  88. const queryEmbedding = await this.embeddingService.getEmbeddings(
  89. [searchQuery],
  90. userId,
  91. effectiveEmbeddingId,
  92. );
  93. const queryVector = queryEmbedding[0];
  94. // 設定に基づいた検索戦略の選択
  95. let results;
  96. if (effectiveEnableFullText) {
  97. results = await this.elasticsearchService.hybridSearch(
  98. queryVector,
  99. searchQuery,
  100. userId,
  101. effectiveTopK * 4,
  102. effectiveHybridWeight,
  103. undefined,
  104. effectiveFileIds,
  105. tenantId,
  106. );
  107. } else {
  108. let vectorSearchResults = await this.elasticsearchService.searchSimilar(
  109. queryVector,
  110. userId,
  111. effectiveTopK * 4,
  112. tenantId,
  113. );
  114. if (effectiveFileIds && effectiveFileIds.length > 0) {
  115. results = vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
  116. } else {
  117. results = vectorSearchResults;
  118. }
  119. }
  120. return results;
  121. });
  122. const allResultsRaw = await Promise.all(searchTasks);
  123. let searchResults = this.deduplicateResults(allResultsRaw.flat());
  124. // 初回の類似度フィルタリング
  125. const initialCount = searchResults.length;
  126. // ログ出力
  127. searchResults.forEach((r, idx) => {
  128. this.logger.log(`Hit ${idx}: score=${r.score.toFixed(4)}, fileName=${r.fileName}`);
  129. });
  130. // 閾値フィルタリングを適用
  131. searchResults = searchResults.filter(r => r.score >= effectiveVectorThreshold);
  132. this.logger.log(`Initial hits: ${initialCount} -> filtered by vectorThreshold: ${searchResults.length}`);
  133. // 3. リランク (Rerank)
  134. let finalResults = searchResults;
  135. if (effectiveEnableRerank && effectiveRerankId && searchResults.length > 0) {
  136. try {
  137. const docs = searchResults.map(r => r.content);
  138. const rerankedIndices = await this.rerankService.rerank(
  139. query,
  140. docs,
  141. userId,
  142. effectiveRerankId,
  143. effectiveTopK * 2 // 少し多めに残す
  144. );
  145. finalResults = rerankedIndices.map(r => {
  146. const originalItem = searchResults[r.index];
  147. return {
  148. ...originalItem,
  149. score: r.score, // Rerank スコア
  150. originalScore: originalItem.score // 元のスコア
  151. };
  152. });
  153. // Rerank後のフィルタリング
  154. const beforeRerankFilter = finalResults.length;
  155. finalResults = finalResults.filter(r => r.score >= effectiveRerankThreshold);
  156. this.logger.log(`After rerank: ${beforeRerankFilter} -> filtered by rerankThreshold: ${finalResults.length}`);
  157. } catch (error) {
  158. this.logger.warn(`Rerank failed, falling back to filtered vector search: ${error.message}`);
  159. // 失敗した場合はベクトル検索の結果をそのまま使う
  160. }
  161. }
  162. // 最終的な件数制限
  163. finalResults = finalResults.slice(0, effectiveTopK);
  164. // 4. RAG 結果形式に変換
  165. const ragResults: RagSearchResult[] = finalResults.map((result) => ({
  166. content: result.content,
  167. fileName: result.fileName,
  168. score: result.score,
  169. originalScore: result.originalScore !== undefined ? result.originalScore : result.score,
  170. chunkIndex: result.chunkIndex,
  171. fileId: result.fileId,
  172. metadata: result.metadata,
  173. }));
  174. return ragResults;
  175. } catch (error) {
  176. this.logger.error('RAG search failed:', error);
  177. return [];
  178. }
  179. }
  180. buildRagPrompt(
  181. query: string,
  182. searchResults: RagSearchResult[],
  183. language: string = 'ja',
  184. ): string {
  185. const lang = language || 'ja';
  186. // コンテキストの構築
  187. let context = '';
  188. if (searchResults.length === 0) {
  189. context = this.i18nService.getMessage('ragNoDocumentFound', lang);
  190. } else {
  191. // ファイルごとにグループ化
  192. const fileGroups = new Map<string, RagSearchResult[]>();
  193. searchResults.forEach((result) => {
  194. if (!fileGroups.has(result.fileName)) {
  195. fileGroups.set(result.fileName, []);
  196. }
  197. fileGroups.get(result.fileName)!.push(result);
  198. });
  199. // コンテキスト文字列を構築
  200. const contextParts: string[] = [];
  201. fileGroups.forEach((chunks, fileName) => {
  202. contextParts.push(this.i18nService.formatMessage('ragSource', { fileName }, lang));
  203. chunks.forEach((chunk, index) => {
  204. contextParts.push(
  205. this.i18nService.formatMessage('ragSegment', {
  206. index: index + 1,
  207. score: chunk.score.toFixed(3)
  208. }, lang),
  209. );
  210. contextParts.push(chunk.content);
  211. contextParts.push('');
  212. });
  213. });
  214. context = contextParts.join('\n');
  215. }
  216. const langText =
  217. lang === 'zh' ? '中文' : lang === 'en' ? 'English' : '日本語';
  218. const systemPrompt = this.i18nService.getMessage('ragSystemPrompt', lang);
  219. const rules = this.i18nService.formatMessage('ragRules', { lang: langText }, lang);
  220. const docContentHeader = this.i18nService.getMessage('ragDocumentContent', lang);
  221. const userQuestionHeader = this.i18nService.getMessage('ragUserQuestion', lang);
  222. const answerHeader = this.i18nService.getMessage('ragAnswer', lang);
  223. return `${systemPrompt}
  224. ${rules}
  225. ${docContentHeader}
  226. ${context}
  227. ${userQuestionHeader}
  228. ${query}
  229. ${answerHeader}`;
  230. }
  231. extractSources(searchResults: RagSearchResult[]): string[] {
  232. const uniqueFiles = new Set<string>();
  233. searchResults.forEach((result) => {
  234. uniqueFiles.add(result.fileName);
  235. });
  236. return Array.from(uniqueFiles);
  237. }
  238. /**
  239. * 検索結果の重複排除
  240. */
  241. private deduplicateResults(results: any[]): any[] {
  242. const unique = new Map<string, any>();
  243. results.forEach(r => {
  244. const key = `${r.fileId}_${r.chunkIndex}`;
  245. if (!unique.has(key) || unique.get(key)!.score < r.score) {
  246. unique.set(key, r);
  247. }
  248. });
  249. return Array.from(unique.values()).sort((a, b) => b.score - a.score);
  250. }
  251. /**
  252. * クエリを拡張してバリエーションを生成
  253. */
  254. async expandQuery(query: string, userId: string, tenantId?: string): Promise<string[]> {
  255. try {
  256. const llm = await this.getInternalLlm(userId, tenantId || 'default');
  257. if (!llm) return [query];
  258. const userSettings = await this.userSettingService.findOrCreate(userId);
  259. const lang = userSettings.language || 'ja';
  260. const prompt = this.i18nService.formatMessage('queryExpansionPrompt', { query }, lang);
  261. const response = await llm.invoke(prompt);
  262. const content = String(response.content);
  263. const expandedQueries = content
  264. .split('\n')
  265. .map(q => q.trim())
  266. .filter(q => q.length > 0)
  267. .slice(0, 3); // 最大3つに制限
  268. this.logger.log(`Query expanded: "${query}" -> [${expandedQueries.join(', ')}]`);
  269. return expandedQueries.length > 0 ? expandedQueries : [query];
  270. } catch (error) {
  271. this.logger.error('Query expansion failed:', error);
  272. return [query];
  273. }
  274. }
  275. /**
  276. * 仮想的なドキュメント(HyDE)を生成
  277. */
  278. async generateHyDE(query: string, userId: string, tenantId?: string): Promise<string> {
  279. try {
  280. const llm = await this.getInternalLlm(userId, tenantId || 'default');
  281. if (!llm) return query;
  282. const userSettings = await this.userSettingService.findOrCreate(userId);
  283. const lang = userSettings.language || 'ja';
  284. const prompt = this.i18nService.formatMessage('hydePrompt', { query }, lang);
  285. const response = await llm.invoke(prompt);
  286. const hydeDoc = String(response.content).trim();
  287. this.logger.log(`HyDE generated for: "${query}" (length: ${hydeDoc.length})`);
  288. return hydeDoc || query;
  289. } catch (error) {
  290. this.logger.error('HyDE generation failed:', error);
  291. return query;
  292. }
  293. }
  294. /**
  295. * 内部タスク用の LLM インスタンスを取得
  296. */
  297. private async getInternalLlm(userId: string, tenantId: string): Promise<ChatOpenAI | null> {
  298. try {
  299. const models = await this.modelConfigService.findAll(userId, tenantId || 'default');
  300. const defaultLlm = models.find(m => m.type === 'llm' && m.isDefault && m.isEnabled !== false);
  301. if (!defaultLlm) {
  302. this.logger.warn('No default LLM configured for internal tasks');
  303. return null;
  304. }
  305. return new ChatOpenAI({
  306. apiKey: defaultLlm.apiKey || 'ollama',
  307. temperature: 0.3,
  308. modelName: defaultLlm.modelId,
  309. configuration: {
  310. baseURL: defaultLlm.baseUrl || 'http://localhost:11434/v1',
  311. },
  312. });
  313. } catch (error) {
  314. this.logger.error('Failed to get internal LLM:', error);
  315. return null;
  316. }
  317. }
  318. }