|
|
@@ -6,6 +6,8 @@ import { ModelConfigService } from '../model-config/model-config.service';
|
|
|
import { RerankService } from './rerank.service';
|
|
|
import { I18nService } from '../i18n/i18n.service';
|
|
|
import { UserSettingService } from '../user-setting/user-setting.service';
|
|
|
+import { ChatOpenAI } from '@langchain/openai';
|
|
|
+import { ModelConfig } from '../types';
|
|
|
|
|
|
export interface RagSearchResult {
|
|
|
content: string;
|
|
|
@@ -51,6 +53,8 @@ export class RagService {
|
|
|
selectedGroups?: string[],
|
|
|
effectiveFileIds?: string[],
|
|
|
rerankSimilarityThreshold: number = 0.5, // Rerankのしきい値(デフォルト0.5)
|
|
|
+ enableQueryExpansion?: boolean,
|
|
|
+ enableHyDE?: boolean,
|
|
|
): Promise<RagSearchResult[]> {
|
|
|
// 1. グローバル設定の取得
|
|
|
const globalSettings = await this.userSettingService.getGlobalSettings();
|
|
|
@@ -64,64 +68,71 @@ export class RagService {
|
|
|
const effectiveEmbeddingId = embeddingModelId || globalSettings.selectedEmbeddingId;
|
|
|
const effectiveRerankId = rerankModelId || globalSettings.selectedRerankId;
|
|
|
const effectiveHybridWeight = globalSettings.hybridVectorWeight ?? 0.7;
|
|
|
+ const effectiveEnableQueryExpansion = enableQueryExpansion !== undefined ? enableQueryExpansion : globalSettings.enableQueryExpansion;
|
|
|
+ const effectiveEnableHyDE = enableHyDE !== undefined ? enableHyDE : globalSettings.enableHyDE;
|
|
|
|
|
|
this.logger.log(
|
|
|
- `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}`,
|
|
|
+ `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}, QueryExpansion=${effectiveEnableQueryExpansion}, HyDE=${effectiveEnableHyDE}`,
|
|
|
);
|
|
|
|
|
|
try {
|
|
|
- // 1. クエリベクトルの取得
|
|
|
+ // 1. クエリの準備(拡張または HyDE)
|
|
|
+ let queriesToSearch = [query];
|
|
|
+
|
|
|
+ if (effectiveEnableHyDE) {
|
|
|
+ const hydeDoc = await this.generateHyDE(query, userId);
|
|
|
+ queriesToSearch = [hydeDoc]; // HyDE の場合は仮想ドキュメントをクエリとして使用
|
|
|
+ } else if (effectiveEnableQueryExpansion) {
|
|
|
+ const expanded = await this.expandQuery(query, userId);
|
|
|
+ queriesToSearch = [...new Set([query, ...expanded])];
|
|
|
+ }
|
|
|
+
|
|
|
+ // 埋め込みモデルIDが提供されているか確認
|
|
|
if (!effectiveEmbeddingId) {
|
|
|
throw new Error('埋め込みモデルIDが提供されていません');
|
|
|
}
|
|
|
|
|
|
- const queryEmbedding = await this.embeddingService.getEmbeddings(
|
|
|
- [query],
|
|
|
- userId,
|
|
|
- effectiveEmbeddingId,
|
|
|
- );
|
|
|
- const queryVector = queryEmbedding[0];
|
|
|
-
|
|
|
- this.logger.log(`使用するベクトル次元数: ${queryVector?.length}`);
|
|
|
-
|
|
|
- // 2. 設定に基づいた検索戦略の選択
|
|
|
- let searchResults;
|
|
|
- if (effectiveEnableFullText) {
|
|
|
- // ハイブリッド検索
|
|
|
- // 重要: ここでの 0.7 はハイブリッドの重み。閾値フィルタリングは後で「生のスコア」に対して行う。
|
|
|
- // ElasticsearchService.hybridSearch は内部でベクトル検索と全文検索それぞれのスコアを持つ
|
|
|
- searchResults = await this.elasticsearchService.hybridSearch(
|
|
|
- queryVector,
|
|
|
- query,
|
|
|
- userId,
|
|
|
- effectiveTopK * 4, // Rerankのために少し多めに取得
|
|
|
- effectiveHybridWeight, // vectorWeight
|
|
|
- undefined,
|
|
|
- effectiveFileIds
|
|
|
- );
|
|
|
- } else {
|
|
|
- // ベクトル検索のみ
|
|
|
- let vectorSearchResults = await this.elasticsearchService.searchSimilar(
|
|
|
- queryVector,
|
|
|
+ // 2. 複数のクエリに対して並列検索
|
|
|
+ const searchTasks = queriesToSearch.map(async (searchQuery) => {
|
|
|
+ // クエリベクトルの取得
|
|
|
+ const queryEmbedding = await this.embeddingService.getEmbeddings(
|
|
|
+ [searchQuery],
|
|
|
userId,
|
|
|
- effectiveTopK * 4 // Rerankのために少し多めに取得
|
|
|
+ effectiveEmbeddingId,
|
|
|
);
|
|
|
-
|
|
|
- // フィルタリング
|
|
|
- if (effectiveFileIds && effectiveFileIds.length > 0) {
|
|
|
- searchResults = vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
|
|
|
+ const queryVector = queryEmbedding[0];
|
|
|
+
|
|
|
+ // 設定に基づいた検索戦略の選択
|
|
|
+ let results;
|
|
|
+ if (effectiveEnableFullText) {
|
|
|
+ results = await this.elasticsearchService.hybridSearch(
|
|
|
+ queryVector,
|
|
|
+ searchQuery,
|
|
|
+ userId,
|
|
|
+ effectiveTopK * 4,
|
|
|
+ effectiveHybridWeight,
|
|
|
+ undefined,
|
|
|
+ effectiveFileIds
|
|
|
+ );
|
|
|
} else {
|
|
|
- searchResults = vectorSearchResults;
|
|
|
+ let vectorSearchResults = await this.elasticsearchService.searchSimilar(
|
|
|
+ queryVector,
|
|
|
+ userId,
|
|
|
+ effectiveTopK * 4
|
|
|
+ );
|
|
|
+ if (effectiveFileIds && effectiveFileIds.length > 0) {
|
|
|
+ results = vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
|
|
|
+ } else {
|
|
|
+ results = vectorSearchResults;
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
+ return results;
|
|
|
+ });
|
|
|
|
|
|
- // 初回の類似度フィルタリング
|
|
|
- // 修正: ハイブリッド検索の場合、各要素の raw スコア(加重計算前)でチェックするのが理想だが、
|
|
|
- // 現状の hybridSearch は combinedScore を .score に入れている。
|
|
|
- // ただし、もし vectorWeight=0.7 の場合、相似度 0.4 のものは 0.28 になり、0.3 閾値で消えてしまう。
|
|
|
- // これを避けるため、閾値チェックを「加重計算の影響を考慮した値」または「加重計算前」に行う必要がある。
|
|
|
- // ここでは、ユーザーの期待に合わせるため、フィルタリングロジックを微調整する。
|
|
|
+ const allResultsRaw = await Promise.all(searchTasks);
|
|
|
+ let searchResults = this.deduplicateResults(allResultsRaw.flat());
|
|
|
|
|
|
+ // 初回の類似度フィルタリング
|
|
|
const initialCount = searchResults.length;
|
|
|
|
|
|
// ログ出力
|
|
|
@@ -255,4 +266,97 @@ ${answerHeader}`;
|
|
|
});
|
|
|
return Array.from(uniqueFiles);
|
|
|
}
|
|
|
+
|
|
|
+  /**
+   * Collapses search hits that point at the same chunk.
+   *
+   * Hits are keyed by their `fileId` and `chunkIndex`. When several hits share
+   * a key, the one with the highest `score` is kept; on a tie the hit seen
+   * first wins.
+   *
+   * @param results Hits gathered from one or more searches.
+   * @returns The surviving hits, ordered by descending `score`.
+   */
+  private deduplicateResults(results: any[]): any[] {
+    const bestByChunk = results.reduce((best: Map<string, any>, hit: any) => {
+      const chunkKey = `${hit.fileId}_${hit.chunkIndex}`;
+      const current = best.get(chunkKey);
+      // Replace the stored hit only when the new one scores strictly higher.
+      if (current === undefined || current.score < hit.score) {
+        best.set(chunkKey, hit);
+      }
+      return best;
+    }, new Map<string, any>());
+
+    const survivors = [...bestByChunk.values()];
+    survivors.sort((left, right) => right.score - left.score);
+    return survivors;
+  }
|
|
|
+
|
|
|
+  /**
+   * Asks the LLM returned by `getInternalLlm` for alternative phrasings of a
+   * search query.
+   *
+   * The prompt is the `queryExpansionPrompt` i18n message rendered in the
+   * user's language (falling back to `'ja'`). The reply is read line by line:
+   * each line is trimmed, a leading list marker such as `1.`, `2)`, `-`, `*`
+   * or `•` is removed, empty lines are dropped, and at most three lines are
+   * kept.
+   *
+   * This method never throws. It resolves to `[query]` when no LLM is
+   * available, when the reply contains no usable line, or when any step fails
+   * (the failure is logged).
+   *
+   * @param query The original user query.
+   * @param userId Owner of the settings and model configuration to use.
+   * @returns Up to three query variations, or `[query]` as a fallback.
+   */
+  async expandQuery(query: string, userId: string): Promise<string[]> {
+    try {
+      const llm = await this.getInternalLlm(userId);
+      if (!llm) return [query];
+
+      const userSettings = await this.userSettingService.findOrCreate(userId);
+      const lang = userSettings.language || 'ja';
+      const prompt = this.i18nService.formatMessage('queryExpansionPrompt', { query }, lang);
+
+      const response = await llm.invoke(prompt);
+
+      // Message content may be a plain string or an array of content parts.
+      // String() on an array of parts would produce "[object Object]", so the
+      // text of each part is joined instead; parts without text are ignored.
+      const rawContent: unknown = response.content;
+      const content = Array.isArray(rawContent)
+        ? rawContent
+            .map((part: unknown) => {
+              if (typeof part === 'string') return part;
+              const text = (part as { text?: unknown } | null)?.text;
+              return typeof text === 'string' ? text : '';
+            })
+            .join('')
+        : String(rawContent);
+
+      const expandedQueries = content
+        .split('\n')
+        // Drop list numbering/bullets so they are not embedded and searched
+        // as part of the query text.
+        .map((line) => line.trim().replace(/^(?:[-*•]|\d+[.)])\s+/, '').trim())
+        .filter((line) => line.length > 0)
+        .slice(0, 3); // cap at three variations
+
+      this.logger.log(`Query expanded: "${query}" -> [${expandedQueries.join(', ')}]`);
+      return expandedQueries.length > 0 ? expandedQueries : [query];
+    } catch (error) {
+      // Best effort: expansion problems must not break the search itself.
+      this.logger.error('Query expansion failed:', error);
+      return [query];
+    }
+  }
|
|
|
+
|
|
|
+  /**
+   * Generates a hypothetical document (HyDE) for a query using the LLM
+   * returned by `getInternalLlm`.
+   *
+   * The prompt is the `hydePrompt` i18n message rendered in the user's
+   * language (falling back to `'ja'`). The trimmed reply is returned.
+   *
+   * This method never throws. It resolves to the original `query` when no LLM
+   * is available, when the reply is empty, or when any step fails (the
+   * failure is logged).
+   *
+   * @param query The original user query.
+   * @param userId Owner of the settings and model configuration to use.
+   * @returns The generated document text, or `query` as a fallback.
+   */
+  async generateHyDE(query: string, userId: string): Promise<string> {
+    try {
+      const llm = await this.getInternalLlm(userId);
+      if (!llm) return query;
+
+      const userSettings = await this.userSettingService.findOrCreate(userId);
+      const lang = userSettings.language || 'ja';
+      const prompt = this.i18nService.formatMessage('hydePrompt', { query }, lang);
+
+      const response = await llm.invoke(prompt);
+
+      // Message content may be a plain string or an array of content parts.
+      // String() on an array of parts would produce "[object Object]", so the
+      // text of each part is joined instead; parts without text are ignored.
+      const rawContent: unknown = response.content;
+      const text = Array.isArray(rawContent)
+        ? rawContent
+            .map((part: unknown) => {
+              if (typeof part === 'string') return part;
+              const partText = (part as { text?: unknown } | null)?.text;
+              return typeof partText === 'string' ? partText : '';
+            })
+            .join('')
+        : String(rawContent);
+      const hydeDoc = text.trim();
+
+      this.logger.log(`HyDE generated for: "${query}" (length: ${hydeDoc.length})`);
+      return hydeDoc || query;
+    } catch (error) {
+      // Best effort: HyDE problems must not break the search itself.
+      this.logger.error('HyDE generation failed:', error);
+      return query;
+    }
+  }
|
|
|
+
|
|
|
+  /**
+   * Builds the chat model used for internal helper tasks.
+   *
+   * Looks through the model configurations returned for `userId` and picks
+   * the first one of type `'llm'` that is marked default and is not explicitly
+   * disabled. The client is created with temperature 0.3; a missing API key
+   * falls back to `'ollama'` and a missing base URL to
+   * `http://localhost:11434/v1`.
+   *
+   * @param userId Owner of the model configurations to search.
+   * @returns A `ChatOpenAI` client, or `null` when no matching configuration
+   *   exists or when lookup/construction throws (the error is logged).
+   */
+  private async getInternalLlm(userId: string): Promise<ChatOpenAI | null> {
+    try {
+      const configuredModels = await this.modelConfigService.findAll(userId);
+      const defaultLlm = configuredModels.find((candidate) => {
+        if (candidate.type !== 'llm') return false;
+        if (!candidate.isDefault) return false;
+        // Only an explicit `false` disables a model; an unset flag counts as enabled.
+        return candidate.isEnabled !== false;
+      });
+
+      if (defaultLlm) {
+        const endpoint = { baseURL: defaultLlm.baseUrl || 'http://localhost:11434/v1' };
+        return new ChatOpenAI({
+          apiKey: defaultLlm.apiKey || 'ollama',
+          temperature: 0.3,
+          modelName: defaultLlm.modelId,
+          configuration: endpoint,
+        });
+      }
+
+      this.logger.warn('No default LLM configured for internal tasks');
+      return null;
+    } catch (error) {
+      this.logger.error('Failed to get internal LLM:', error);
+      return null;
+    }
+  }
|
|
|
}
|