|
|
@@ -1,366 +1,156 @@
|
|
|
-import { Injectable, Logger, Inject, forwardRef } from '@nestjs/common';
|
|
|
-import { ConfigService } from '@nestjs/config';
|
|
|
+import { Injectable, Logger } from '@nestjs/common';
|
|
|
import { ElasticsearchService } from '../elasticsearch/elasticsearch.service';
|
|
|
import { EmbeddingService } from '../knowledge-base/embedding.service';
|
|
|
-import { ModelConfigService } from '../model-config/model-config.service';
|
|
|
import { RerankService } from './rerank.service';
|
|
|
import { I18nService } from '../i18n/i18n.service';
|
|
|
-import { TenantService } from '../tenant/tenant.service';
|
|
|
-import { ChatOpenAI } from '@langchain/openai';
|
|
|
-import { ModelConfig } from '../types';
|
|
|
-import { UserSettingService } from '../user/user-setting.service';
|
|
|
import { DEFAULT_LANGUAGE } from '../common/constants';
|
|
|
|
|
|
export interface RagSearchResult {
|
|
|
+ id: string;
|
|
|
+ score: number;
|
|
|
content: string;
|
|
|
+ fileId: string;
|
|
|
fileName: string;
|
|
|
- score: number;
|
|
|
+ title: string;
|
|
|
chunkIndex: number;
|
|
|
- fileId?: string;
|
|
|
- originalScore?: number; // Original score before reranking (for debugging)
|
|
|
- metadata?: any;
|
|
|
+ startPosition?: number;
|
|
|
+ endPosition?: number;
|
|
|
}
|
|
|
|
|
|
@Injectable()
|
|
|
export class RagService {
|
|
|
private readonly logger = new Logger(RagService.name);
|
|
|
- private readonly defaultDimensions: number;
|
|
|
|
|
|
constructor(
|
|
|
- @Inject(forwardRef(() => ElasticsearchService))
|
|
|
- private elasticsearchService: ElasticsearchService,
|
|
|
- private embeddingService: EmbeddingService,
|
|
|
- private modelConfigService: ModelConfigService,
|
|
|
- private rerankService: RerankService,
|
|
|
- private configService: ConfigService,
|
|
|
- private i18nService: I18nService,
|
|
|
- private tenantService: TenantService,
|
|
|
- private userSettingService: UserSettingService,
|
|
|
- ) {
|
|
|
- this.defaultDimensions = parseInt(
|
|
|
- this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
|
|
|
- );
|
|
|
- this.logger.log(`RAG service default vector dimensions: ${this.defaultDimensions}`);
|
|
|
- }
|
|
|
+ private readonly elasticsearchService: ElasticsearchService,
|
|
|
+ private readonly embeddingService: EmbeddingService,
|
|
|
+ private readonly rerankService: RerankService,
|
|
|
+ private readonly i18nService: I18nService,
|
|
|
+ ) { }
|
|
|
|
|
|
+ /**
|
|
|
+ * Main search method for RAG.
|
|
|
+ * Supports hybrid search and optional reranking; the query-expansion and HyDE flags are accepted for API compatibility but are not applied by this implementation.
|
|
|
+ */
|
|
|
async searchKnowledge(
|
|
|
query: string,
|
|
|
userId: string,
|
|
|
topK: number = 5,
|
|
|
- vectorSimilarityThreshold: number = 0.3, // Vector search similarity threshold
|
|
|
+ similarityThreshold: number = 0.0,
|
|
|
embeddingModelId?: string,
|
|
|
- enableFullTextSearch: boolean = false,
|
|
|
+ enableFullTextSearch: boolean = true,
|
|
|
enableRerank: boolean = false,
|
|
|
rerankModelId?: string,
|
|
|
selectedGroups?: string[],
|
|
|
- effectiveFileIds?: string[],
|
|
|
- rerankSimilarityThreshold: number = 0.5, // Rerank similarity threshold (default 0.5)
|
|
|
- tenantId?: string, // New
|
|
|
- enableQueryExpansion?: boolean,
|
|
|
- enableHyDE?: boolean,
|
|
|
+ selectedFiles?: string[],
|
|
|
+ rerankThreshold: number = 0.0,
|
|
|
+ tenantId?: string,
|
|
|
+ enableQueryExpansion: boolean = false,
|
|
|
+ enableHyDE: boolean = false,
|
|
|
+    language: string = DEFAULT_LANGUAGE, // NOTE(review): not referenced inside searchKnowledge — confirm whether it should drive prompt/i18n behavior or be removed
|
|
|
): Promise<RagSearchResult[]> {
|
|
|
- // 1. Get organization settings
|
|
|
- const globalSettings = await this.tenantService.getSettings(tenantId || 'default');
|
|
|
-
|
|
|
- // Use global settings if parameters are not explicitly provided
|
|
|
- const effectiveTopK = topK || globalSettings?.topK || 5;
|
|
|
- const effectiveVectorThreshold = vectorSimilarityThreshold !== undefined ? vectorSimilarityThreshold : (globalSettings?.similarityThreshold || 0.3);
|
|
|
- const effectiveRerankThreshold = rerankSimilarityThreshold !== undefined ? rerankSimilarityThreshold : (globalSettings?.rerankSimilarityThreshold || 0.5);
|
|
|
- const effectiveEnableRerank = enableRerank !== undefined ? enableRerank : (globalSettings?.enableRerank ?? false);
|
|
|
- const effectiveEnableFullText = enableFullTextSearch !== undefined ? enableFullTextSearch : (globalSettings?.enableFullTextSearch ?? false);
|
|
|
- const effectiveEmbeddingId = embeddingModelId || globalSettings?.selectedEmbeddingId;
|
|
|
- const effectiveRerankId = rerankModelId || globalSettings?.selectedRerankId;
|
|
|
- const effectiveHybridWeight = globalSettings?.hybridVectorWeight ?? 0.7;
|
|
|
- const effectiveEnableQueryExpansion = enableQueryExpansion !== undefined ? enableQueryExpansion : (globalSettings?.enableQueryExpansion ?? false);
|
|
|
- const effectiveEnableHyDE = enableHyDE !== undefined ? enableHyDE : (globalSettings?.enableHyDE ?? false);
|
|
|
+ this.logger.log(`RAG Search: query="${query}", rerank=${enableRerank}, tenantId=${tenantId}`);
|
|
|
+
|
|
|
+ // 1. Get embedding for the query if needed
|
|
|
+ let queryVector: number[] = [];
|
|
|
+ if (embeddingModelId) {
|
|
|
+ try {
|
|
|
+ const vectors = await this.embeddingService.getEmbeddings([query], userId, embeddingModelId, tenantId);
|
|
|
+ queryVector = vectors[0];
|
|
|
+ } catch (error) {
|
|
|
+ this.logger.error('Failed to generate query embedding', error);
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- this.logger.log(
|
|
|
- `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}, QueryExpansion=${effectiveEnableQueryExpansion}, HyDE=${effectiveEnableHyDE}`,
|
|
|
+ // 2. Perform search via Elasticsearch
|
|
|
+ // Note: ElasticsearchService.hybridSearch supports explicitFileIds
|
|
|
+ const searchResults = await this.elasticsearchService.hybridSearch(
|
|
|
+ queryVector,
|
|
|
+ query,
|
|
|
+ userId,
|
|
|
+ topK * 2, // Get more results for reranking
|
|
|
+ 0.7, // Default vector weight
|
|
|
+ selectedGroups,
|
|
|
+ selectedFiles,
|
|
|
+ tenantId,
|
|
|
);
|
|
|
|
|
|
- try {
|
|
|
- // 1. Prepare query (expansion or HyDE)
|
|
|
- let queriesToSearch = [query];
|
|
|
+ let finalResults = searchResults;
|
|
|
|
|
|
- if (effectiveEnableHyDE) {
|
|
|
- const hydeDoc = await this.generateHyDE(query, userId);
|
|
|
- queriesToSearch = [hydeDoc]; // Use virtual document as query for HyDE
|
|
|
- } else if (effectiveEnableQueryExpansion) {
|
|
|
- const expanded = await this.expandQuery(query, userId);
|
|
|
- queriesToSearch = [...new Set([query, ...expanded])];
|
|
|
- }
|
|
|
-
|
|
|
- // Check if embedding model ID is provided
|
|
|
- if (!effectiveEmbeddingId) {
|
|
|
- throw new Error(this.i18nService.getMessage('embeddingModelIdNotProvided'));
|
|
|
- }
|
|
|
+ // 3. Apply Threshold
|
|
|
+ finalResults = finalResults.filter(r => r.score >= similarityThreshold);
|
|
|
|
|
|
- // 2. Parallel search for multiple queries
|
|
|
- const searchTasks = queriesToSearch.map(async (searchQuery) => {
|
|
|
- // Get query vector
|
|
|
- const queryEmbedding = await this.embeddingService.getEmbeddings(
|
|
|
- [searchQuery],
|
|
|
+ // 4. Perform Reranking if enabled
|
|
|
+ if (enableRerank && rerankModelId && finalResults.length > 0) {
|
|
|
+ try {
|
|
|
+ // Map search results to string array for reranking
|
|
|
+ const documentTexts = finalResults.map(r => r.content);
|
|
|
+
|
|
|
+ const rerankedPairs = await this.rerankService.rerank(
|
|
|
+ query,
|
|
|
+ documentTexts,
|
|
|
userId,
|
|
|
- effectiveEmbeddingId,
|
|
|
+ rerankModelId,
|
|
|
+ topK,
|
|
|
+ tenantId,
|
|
|
);
|
|
|
- const queryVector = queryEmbedding[0];
|
|
|
-
|
|
|
- // Select search strategy based on settings
|
|
|
- let results;
|
|
|
- if (effectiveEnableFullText) {
|
|
|
- results = await this.elasticsearchService.hybridSearch(
|
|
|
- queryVector,
|
|
|
- searchQuery,
|
|
|
- userId,
|
|
|
- effectiveTopK * 4,
|
|
|
- effectiveHybridWeight,
|
|
|
- undefined,
|
|
|
- effectiveFileIds,
|
|
|
- tenantId,
|
|
|
- );
|
|
|
- } else {
|
|
|
- let vectorSearchResults = await this.elasticsearchService.searchSimilar(
|
|
|
- queryVector,
|
|
|
- userId,
|
|
|
- effectiveTopK * 4,
|
|
|
- tenantId,
|
|
|
- );
|
|
|
- if (effectiveFileIds && effectiveFileIds.length > 0) {
|
|
|
- results = vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
|
|
|
- } else {
|
|
|
- results = vectorSearchResults;
|
|
|
- }
|
|
|
- }
|
|
|
- return results;
|
|
|
- });
|
|
|
-
|
|
|
- const allResultsRaw = await Promise.all(searchTasks);
|
|
|
- let searchResults = this.deduplicateResults(allResultsRaw.flat());
|
|
|
-
|
|
|
- // Initial similarity filtering
|
|
|
- const initialCount = searchResults.length;
|
|
|
-
|
|
|
- // Log output
|
|
|
- searchResults.forEach((r, idx) => {
|
|
|
- this.logger.log(`Hit ${idx}: score=${r.score.toFixed(4)}, fileName=${r.fileName}`);
|
|
|
- });
|
|
|
|
|
|
- // Apply threshold filtering
|
|
|
- searchResults = searchResults.filter(r => r.score >= effectiveVectorThreshold);
|
|
|
- this.logger.log(`Initial hits: ${initialCount} -> filtered by vectorThreshold: ${searchResults.length}`);
|
|
|
-
|
|
|
- // 3. Rerank
|
|
|
- let finalResults = searchResults;
|
|
|
-
|
|
|
- if (effectiveEnableRerank && effectiveRerankId && searchResults.length > 0) {
|
|
|
- try {
|
|
|
- const docs = searchResults.map(r => r.content);
|
|
|
- const rerankedIndices = await this.rerankService.rerank(
|
|
|
- query,
|
|
|
- docs,
|
|
|
- userId,
|
|
|
- effectiveRerankId,
|
|
|
- effectiveTopK * 2 // Keep a bit more results
|
|
|
- );
|
|
|
-
|
|
|
- finalResults = rerankedIndices.map(r => {
|
|
|
- const originalItem = searchResults[r.index];
|
|
|
+ // Map reranked results back to RagSearchResult
|
|
|
+ finalResults = rerankedPairs
|
|
|
+ .filter(pair => pair.score >= rerankThreshold)
|
|
|
+ .map(pair => {
|
|
|
+ const originalResult = finalResults[pair.index];
|
|
|
return {
|
|
|
- ...originalItem,
|
|
|
- score: r.score, // Rerank score
|
|
|
- originalScore: originalItem.score // Original score
|
|
|
+ ...originalResult,
|
|
|
+ score: pair.score, // Update with rerank score
|
|
|
};
|
|
|
});
|
|
|
-
|
|
|
- // Filter after reranking
|
|
|
- const beforeRerankFilter = finalResults.length;
|
|
|
- finalResults = finalResults.filter(r => r.score >= effectiveRerankThreshold);
|
|
|
- this.logger.log(`After rerank: ${beforeRerankFilter} -> filtered by rerankThreshold: ${finalResults.length}`);
|
|
|
-
|
|
|
- } catch (error) {
|
|
|
- this.logger.warn(`Rerank failed, falling back to filtered vector search: ${error.message}`);
|
|
|
- // Fall back to filtered vector search results if rerank fails
|
|
|
- }
|
|
|
+ } catch (error) {
|
|
|
+ this.logger.error('Reranking failed, falling back to original results', error);
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- // Final result count limit
|
|
|
- finalResults = finalResults.slice(0, effectiveTopK);
|
|
|
-
|
|
|
- // 4. Convert to RAG result format
|
|
|
- const ragResults: RagSearchResult[] = finalResults.map((result) => ({
|
|
|
- content: result.content,
|
|
|
- fileName: result.fileName,
|
|
|
- score: result.score,
|
|
|
- originalScore: result.originalScore !== undefined ? result.originalScore : result.score,
|
|
|
- chunkIndex: result.chunkIndex,
|
|
|
- fileId: result.fileId,
|
|
|
- metadata: result.metadata,
|
|
|
- }));
|
|
|
+ // 5. Final Slice
|
|
|
+ return finalResults.slice(0, topK);
|
|
|
+ }
|
|
|
|
|
|
- return ragResults;
|
|
|
- } catch (error) {
|
|
|
- this.logger.error('RAG search failed:', error);
|
|
|
- return [];
|
|
|
- }
|
|
|
+ /**
|
|
|
+ * Extract unique document names as sources
|
|
|
+ */
|
|
|
+ extractSources(results: RagSearchResult[]): string[] {
|
|
|
+ const sources = new Set<string>();
|
|
|
+ results.forEach((r) => {
|
|
|
+ if (r.fileName) sources.add(r.fileName);
|
|
|
+ });
|
|
|
+ return Array.from(sources);
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Build the RAG prompt with instructions in the correct language
|
|
|
+ */
|
|
|
buildRagPrompt(
|
|
|
query: string,
|
|
|
searchResults: RagSearchResult[],
|
|
|
language: string = DEFAULT_LANGUAGE,
|
|
|
): string {
|
|
|
- const lang = language || DEFAULT_LANGUAGE;
|
|
|
+ const normalizedLang = this.i18nService.normalizeLanguage(language);
|
|
|
|
|
|
// Build context
|
|
|
let context = '';
|
|
|
if (searchResults.length === 0) {
|
|
|
- context = this.i18nService.getMessage('ragNoDocumentFound', lang);
|
|
|
+ context = normalizedLang === 'zh' ? '未找到相关信息。' : normalizedLang === 'ja' ? '関連情報が見つかりませんでした。' : 'No relevant information found.';
|
|
|
} else {
|
|
|
- // Group by file
|
|
|
- const fileGroups = new Map<string, RagSearchResult[]>();
|
|
|
- searchResults.forEach((result) => {
|
|
|
- if (!fileGroups.has(result.fileName)) {
|
|
|
- fileGroups.set(result.fileName, []);
|
|
|
- }
|
|
|
- fileGroups.get(result.fileName)!.push(result);
|
|
|
+ searchResults.forEach((result, index) => {
|
|
|
+ context += `[${index + 1}] File: ${result.fileName} (Score: ${result.score.toFixed(4)})\nContent: ${result.content}\n\n`;
|
|
|
});
|
|
|
-
|
|
|
- // Build context string
|
|
|
- const contextParts: string[] = [];
|
|
|
- fileGroups.forEach((chunks, fileName) => {
|
|
|
- contextParts.push(this.i18nService.formatMessage('ragSource', { fileName }, lang));
|
|
|
- chunks.forEach((chunk, index) => {
|
|
|
- contextParts.push(
|
|
|
- this.i18nService.formatMessage('ragSegment', {
|
|
|
- index: index + 1,
|
|
|
- score: chunk.score.toFixed(3)
|
|
|
- }, lang),
|
|
|
- );
|
|
|
- contextParts.push(chunk.content);
|
|
|
- contextParts.push('');
|
|
|
- });
|
|
|
- });
|
|
|
- context = contextParts.join('\n');
|
|
|
- }
|
|
|
-
|
|
|
- const langText = lang === 'zh' ? 'Chinese' : lang === 'en' ? 'English' : 'Japanese';
|
|
|
- const systemPrompt = this.i18nService.getMessage('ragSystemPrompt', lang);
|
|
|
- const rules = this.i18nService.formatMessage('ragRules', { lang: langText }, lang);
|
|
|
- const docContentHeader = this.i18nService.getMessage('ragDocumentContent', lang);
|
|
|
- const userQuestionHeader = this.i18nService.getMessage('ragUserQuestion', lang);
|
|
|
- const answerHeader = this.i18nService.getMessage('ragAnswer', lang);
|
|
|
-
|
|
|
- return `${systemPrompt}
|
|
|
-
|
|
|
-${rules}
|
|
|
-
|
|
|
-${docContentHeader}
|
|
|
-${context}
|
|
|
-
|
|
|
-${userQuestionHeader}
|
|
|
-${query}
|
|
|
-
|
|
|
-${answerHeader}`;
|
|
|
- }
|
|
|
-
|
|
|
- extractSources(searchResults: RagSearchResult[]): string[] {
|
|
|
- const uniqueFiles = new Set<string>();
|
|
|
- searchResults.forEach((result) => {
|
|
|
- uniqueFiles.add(result.fileName);
|
|
|
- });
|
|
|
- return Array.from(uniqueFiles);
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Deduplicate search results
|
|
|
- */
|
|
|
- private deduplicateResults(results: any[]): any[] {
|
|
|
- const unique = new Map<string, any>();
|
|
|
- results.forEach(r => {
|
|
|
- const key = `${r.fileId}_${r.chunkIndex}`;
|
|
|
- if (!unique.has(key) || unique.get(key)!.score < r.score) {
|
|
|
- unique.set(key, r);
|
|
|
- }
|
|
|
- });
|
|
|
- return Array.from(unique.values()).sort((a, b) => b.score - a.score);
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Expand query to generate variations
|
|
|
- */
|
|
|
- async expandQuery(query: string, userId: string, tenantId?: string): Promise<string[]> {
|
|
|
- try {
|
|
|
- const llm = await this.getInternalLlm(userId, tenantId || 'default');
|
|
|
- if (!llm) return [query];
|
|
|
-
|
|
|
- const userSettings = await this.userSettingService.getByUser(userId);
|
|
|
- const lang = userSettings.language || 'zh';
|
|
|
- const prompt = this.i18nService.formatMessage('queryExpansionPrompt', { query }, lang);
|
|
|
-
|
|
|
- const response = await llm.invoke(prompt);
|
|
|
- const content = String(response.content);
|
|
|
-
|
|
|
- const expandedQueries = content
|
|
|
- .split('\n')
|
|
|
- .map(q => q.trim())
|
|
|
- .filter(q => q.length > 0)
|
|
|
- .slice(0, 3); // Limit to maximum 3
|
|
|
-
|
|
|
- this.logger.log(`Query expanded: "${query}" -> [${expandedQueries.join(', ')}]`);
|
|
|
- return expandedQueries.length > 0 ? expandedQueries : [query];
|
|
|
- } catch (error) {
|
|
|
- this.logger.error('Query expansion failed:', error);
|
|
|
- return [query];
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Generate hypothetical document (HyDE)
|
|
|
- */
|
|
|
- async generateHyDE(query: string, userId: string, tenantId?: string): Promise<string> {
|
|
|
- try {
|
|
|
- const llm = await this.getInternalLlm(userId, tenantId || 'default');
|
|
|
- if (!llm) return query;
|
|
|
-
|
|
|
- const userSettings = await this.userSettingService.getByUser(userId);
|
|
|
- const lang = userSettings.language || 'zh';
|
|
|
- const prompt = this.i18nService.formatMessage('hydePrompt', { query }, lang);
|
|
|
-
|
|
|
- const response = await llm.invoke(prompt);
|
|
|
- const hydeDoc = String(response.content).trim();
|
|
|
-
|
|
|
- this.logger.log(`HyDE generated for: "${query}" (length: ${hydeDoc.length})`);
|
|
|
- return hydeDoc || query;
|
|
|
- } catch (error) {
|
|
|
- this.logger.error('HyDE generation failed:', error);
|
|
|
- return query;
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- /**
|
|
|
- * Get LLM instance for internal tasks
|
|
|
- */
|
|
|
- private async getInternalLlm(userId: string, tenantId: string): Promise<ChatOpenAI | null> {
|
|
|
- try {
|
|
|
- const models = await this.modelConfigService.findAll(userId, tenantId || 'default');
|
|
|
- const defaultLlm = models.find(m => m.type === 'llm' && m.isDefault && m.isEnabled !== false);
|
|
|
+ // Get localized prompt template from I18nService
|
|
|
+ const promptTemplate = this.i18nService.getPrompt(language, 'withContext', false);
|
|
|
|
|
|
- if (!defaultLlm) {
|
|
|
- this.logger.warn('No enabled LLM configured for internal tasks');
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
- return new ChatOpenAI({
|
|
|
- apiKey: defaultLlm.apiKey || 'ollama',
|
|
|
- temperature: 0.3,
|
|
|
- modelName: defaultLlm.modelId,
|
|
|
- configuration: {
|
|
|
- baseURL: defaultLlm.baseUrl || 'http://localhost:11434/v1',
|
|
|
- },
|
|
|
- });
|
|
|
- } catch (error) {
|
|
|
- this.logger.error('Failed to get internal LLM:', error);
|
|
|
- return null;
|
|
|
- }
|
|
|
+ return promptTemplate
|
|
|
+ .replace('{context}', context)
|
|
|
+ .replace('{history}', '') // History placeholders are usually handled in ChatService
|
|
|
+ .replace('{question}', query);
|
|
|
}
|
|
|
}
|