| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- import { Injectable, Logger, Inject, forwardRef } from '@nestjs/common';
- import { ConfigService } from '@nestjs/config';
- import { ElasticsearchService } from '../elasticsearch/elasticsearch.service';
- import { EmbeddingService } from '../knowledge-base/embedding.service';
- import { ModelConfigService } from '../model-config/model-config.service';
- import { RerankService } from './rerank.service';
- import { I18nService } from '../i18n/i18n.service';
- import { TenantService } from '../tenant/tenant.service';
- import { ChatOpenAI } from '@langchain/openai';
- import { ModelConfig } from '../types';
- import { UserSettingService } from '../user/user-setting.service';
/** One retrieved knowledge-base chunk returned by RagService.searchKnowledge. */
export interface RagSearchResult {
  /** Text content of the matched chunk. */
  content: string;
  /** Name of the source file the chunk came from. */
  fileName: string;
  /** Relevance score: the rerank score when reranking was applied, otherwise the vector/hybrid search score. */
  score: number;
  /** Position of the chunk within its source file. */
  chunkIndex: number;
  /** Identifier of the source file, when the search backend provides one. */
  fileId?: string;
  originalScore?: number; // Original score before reranking (for debugging)
  metadata?: any; // passed through unchanged from the underlying search result
}
- @Injectable()
- export class RagService {
- private readonly logger = new Logger(RagService.name);
- private readonly defaultDimensions: number;
- constructor(
- @Inject(forwardRef(() => ElasticsearchService))
- private elasticsearchService: ElasticsearchService,
- private embeddingService: EmbeddingService,
- private modelConfigService: ModelConfigService,
- private rerankService: RerankService,
- private configService: ConfigService,
- private i18nService: I18nService,
- private tenantService: TenantService,
- private userSettingService: UserSettingService,
- ) {
- this.defaultDimensions = parseInt(
- this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
- );
- this.logger.log(`RAG service default vector dimensions: ${this.defaultDimensions}`);
- }
- async searchKnowledge(
- query: string,
- userId: string,
- topK: number = 5,
- vectorSimilarityThreshold: number = 0.3, // Vector search similarity threshold
- embeddingModelId?: string,
- enableFullTextSearch: boolean = false,
- enableRerank: boolean = false,
- rerankModelId?: string,
- selectedGroups?: string[],
- effectiveFileIds?: string[],
- rerankSimilarityThreshold: number = 0.5, // Rerank similarity threshold (default 0.5)
- tenantId?: string, // New
- enableQueryExpansion?: boolean,
- enableHyDE?: boolean,
- ): Promise<RagSearchResult[]> {
- // 1. Get organization settings
- const globalSettings = await this.tenantService.getSettings(tenantId || 'default');
- // Use global settings if parameters are not explicitly provided
- const effectiveTopK = topK || globalSettings?.topK || 5;
- const effectiveVectorThreshold = vectorSimilarityThreshold !== undefined ? vectorSimilarityThreshold : (globalSettings?.similarityThreshold || 0.3);
- const effectiveRerankThreshold = rerankSimilarityThreshold !== undefined ? rerankSimilarityThreshold : (globalSettings?.rerankSimilarityThreshold || 0.5);
- const effectiveEnableRerank = enableRerank !== undefined ? enableRerank : (globalSettings?.enableRerank ?? false);
- const effectiveEnableFullText = enableFullTextSearch !== undefined ? enableFullTextSearch : (globalSettings?.enableFullTextSearch ?? false);
- const effectiveEmbeddingId = embeddingModelId || globalSettings?.selectedEmbeddingId;
- const effectiveRerankId = rerankModelId || globalSettings?.selectedRerankId;
- const effectiveHybridWeight = globalSettings?.hybridVectorWeight ?? 0.7;
- const effectiveEnableQueryExpansion = enableQueryExpansion !== undefined ? enableQueryExpansion : (globalSettings?.enableQueryExpansion ?? false);
- const effectiveEnableHyDE = enableHyDE !== undefined ? enableHyDE : (globalSettings?.enableHyDE ?? false);
- this.logger.log(
- `RAG search: query="${query}", topK=${effectiveTopK}, vectorThreshold=${effectiveVectorThreshold}, rerankThreshold=${effectiveRerankThreshold}, hybridWeight=${effectiveHybridWeight}, QueryExpansion=${effectiveEnableQueryExpansion}, HyDE=${effectiveEnableHyDE}`,
- );
- try {
- // 1. Prepare query (expansion or HyDE)
- let queriesToSearch = [query];
- if (effectiveEnableHyDE) {
- const hydeDoc = await this.generateHyDE(query, userId);
- queriesToSearch = [hydeDoc]; // Use virtual document as query for HyDE
- } else if (effectiveEnableQueryExpansion) {
- const expanded = await this.expandQuery(query, userId);
- queriesToSearch = [...new Set([query, ...expanded])];
- }
- // Check if embedding model ID is provided
- if (!effectiveEmbeddingId) {
- throw new Error(this.i18nService.getMessage('embeddingModelIdNotProvided'));
- }
- // 2. Parallel search for multiple queries
- const searchTasks = queriesToSearch.map(async (searchQuery) => {
- // Get query vector
- const queryEmbedding = await this.embeddingService.getEmbeddings(
- [searchQuery],
- userId,
- effectiveEmbeddingId,
- );
- const queryVector = queryEmbedding[0];
- // Select search strategy based on settings
- let results;
- if (effectiveEnableFullText) {
- results = await this.elasticsearchService.hybridSearch(
- queryVector,
- searchQuery,
- userId,
- effectiveTopK * 4,
- effectiveHybridWeight,
- undefined,
- effectiveFileIds,
- tenantId,
- );
- } else {
- let vectorSearchResults = await this.elasticsearchService.searchSimilar(
- queryVector,
- userId,
- effectiveTopK * 4,
- tenantId,
- );
- if (effectiveFileIds && effectiveFileIds.length > 0) {
- results = vectorSearchResults.filter(r => effectiveFileIds.includes(r.fileId));
- } else {
- results = vectorSearchResults;
- }
- }
- return results;
- });
- const allResultsRaw = await Promise.all(searchTasks);
- let searchResults = this.deduplicateResults(allResultsRaw.flat());
- // Initial similarity filtering
- const initialCount = searchResults.length;
- // Log output
- searchResults.forEach((r, idx) => {
- this.logger.log(`Hit ${idx}: score=${r.score.toFixed(4)}, fileName=${r.fileName}`);
- });
- // Apply threshold filtering
- searchResults = searchResults.filter(r => r.score >= effectiveVectorThreshold);
- this.logger.log(`Initial hits: ${initialCount} -> filtered by vectorThreshold: ${searchResults.length}`);
- // 3. Rerank
- let finalResults = searchResults;
- if (effectiveEnableRerank && effectiveRerankId && searchResults.length > 0) {
- try {
- const docs = searchResults.map(r => r.content);
- const rerankedIndices = await this.rerankService.rerank(
- query,
- docs,
- userId,
- effectiveRerankId,
- effectiveTopK * 2 // Keep a bit more results
- );
- finalResults = rerankedIndices.map(r => {
- const originalItem = searchResults[r.index];
- return {
- ...originalItem,
- score: r.score, // Rerank score
- originalScore: originalItem.score // Original score
- };
- });
- // Filter after reranking
- const beforeRerankFilter = finalResults.length;
- finalResults = finalResults.filter(r => r.score >= effectiveRerankThreshold);
- this.logger.log(`After rerank: ${beforeRerankFilter} -> filtered by rerankThreshold: ${finalResults.length}`);
- } catch (error) {
- this.logger.warn(`Rerank failed, falling back to filtered vector search: ${error.message}`);
- // Fall back to filtered vector search results if rerank fails
- }
- }
- // Final result count limit
- finalResults = finalResults.slice(0, effectiveTopK);
- // 4. Convert to RAG result format
- const ragResults: RagSearchResult[] = finalResults.map((result) => ({
- content: result.content,
- fileName: result.fileName,
- score: result.score,
- originalScore: result.originalScore !== undefined ? result.originalScore : result.score,
- chunkIndex: result.chunkIndex,
- fileId: result.fileId,
- metadata: result.metadata,
- }));
- return ragResults;
- } catch (error) {
- this.logger.error('RAG search failed:', error);
- return [];
- }
- }
- buildRagPrompt(
- query: string,
- searchResults: RagSearchResult[],
- language: string = 'ja',
- ): string {
- const lang = language || 'ja';
- // Build context
- let context = '';
- if (searchResults.length === 0) {
- context = this.i18nService.getMessage('ragNoDocumentFound', lang);
- } else {
- // Group by file
- const fileGroups = new Map<string, RagSearchResult[]>();
- searchResults.forEach((result) => {
- if (!fileGroups.has(result.fileName)) {
- fileGroups.set(result.fileName, []);
- }
- fileGroups.get(result.fileName)!.push(result);
- });
- // Build context string
- const contextParts: string[] = [];
- fileGroups.forEach((chunks, fileName) => {
- contextParts.push(this.i18nService.formatMessage('ragSource', { fileName }, lang));
- chunks.forEach((chunk, index) => {
- contextParts.push(
- this.i18nService.formatMessage('ragSegment', {
- index: index + 1,
- score: chunk.score.toFixed(3)
- }, lang),
- );
- contextParts.push(chunk.content);
- contextParts.push('');
- });
- });
- context = contextParts.join('\n');
- }
- const langText = lang === 'zh' ? 'Chinese' : lang === 'en' ? 'English' : 'Japanese';
- const systemPrompt = this.i18nService.getMessage('ragSystemPrompt', lang);
- const rules = this.i18nService.formatMessage('ragRules', { lang: langText }, lang);
- const docContentHeader = this.i18nService.getMessage('ragDocumentContent', lang);
- const userQuestionHeader = this.i18nService.getMessage('ragUserQuestion', lang);
- const answerHeader = this.i18nService.getMessage('ragAnswer', lang);
- return `${systemPrompt}
- ${rules}
- ${docContentHeader}
- ${context}
- ${userQuestionHeader}
- ${query}
- ${answerHeader}`;
- }
- extractSources(searchResults: RagSearchResult[]): string[] {
- const uniqueFiles = new Set<string>();
- searchResults.forEach((result) => {
- uniqueFiles.add(result.fileName);
- });
- return Array.from(uniqueFiles);
- }
- /**
- * Deduplicate search results
- */
- private deduplicateResults(results: any[]): any[] {
- const unique = new Map<string, any>();
- results.forEach(r => {
- const key = `${r.fileId}_${r.chunkIndex}`;
- if (!unique.has(key) || unique.get(key)!.score < r.score) {
- unique.set(key, r);
- }
- });
- return Array.from(unique.values()).sort((a, b) => b.score - a.score);
- }
- /**
- * Expand query to generate variations
- */
- async expandQuery(query: string, userId: string, tenantId?: string): Promise<string[]> {
- try {
- const llm = await this.getInternalLlm(userId, tenantId || 'default');
- if (!llm) return [query];
- const userSettings = await this.userSettingService.getByUser(userId);
- const lang = userSettings.language || 'zh';
- const prompt = this.i18nService.formatMessage('queryExpansionPrompt', { query }, lang);
- const response = await llm.invoke(prompt);
- const content = String(response.content);
- const expandedQueries = content
- .split('\n')
- .map(q => q.trim())
- .filter(q => q.length > 0)
- .slice(0, 3); // Limit to maximum 3
- this.logger.log(`Query expanded: "${query}" -> [${expandedQueries.join(', ')}]`);
- return expandedQueries.length > 0 ? expandedQueries : [query];
- } catch (error) {
- this.logger.error('Query expansion failed:', error);
- return [query];
- }
- }
- /**
- * Generate hypothetical document (HyDE)
- */
- async generateHyDE(query: string, userId: string, tenantId?: string): Promise<string> {
- try {
- const llm = await this.getInternalLlm(userId, tenantId || 'default');
- if (!llm) return query;
- const userSettings = await this.userSettingService.getByUser(userId);
- const lang = userSettings.language || 'zh';
- const prompt = this.i18nService.formatMessage('hydePrompt', { query }, lang);
- const response = await llm.invoke(prompt);
- const hydeDoc = String(response.content).trim();
- this.logger.log(`HyDE generated for: "${query}" (length: ${hydeDoc.length})`);
- return hydeDoc || query;
- } catch (error) {
- this.logger.error('HyDE generation failed:', error);
- return query;
- }
- }
- /**
- * Get LLM instance for internal tasks
- */
- private async getInternalLlm(userId: string, tenantId: string): Promise<ChatOpenAI | null> {
- try {
- const models = await this.modelConfigService.findAll(userId, tenantId || 'default');
- const defaultLlm = models.find(m => m.type === 'llm' && m.isDefault && m.isEnabled !== false);
- if (!defaultLlm) {
- this.logger.warn('No enabled LLM configured for internal tasks');
- return null;
- }
- return new ChatOpenAI({
- apiKey: defaultLlm.apiKey || 'ollama',
- temperature: 0.3,
- modelName: defaultLlm.modelId,
- configuration: {
- baseURL: defaultLlm.baseUrl || 'http://localhost:11434/v1',
- },
- });
- } catch (error) {
- this.logger.error('Failed to get internal LLM:', error);
- return null;
- }
- }
- }
|