| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import { Injectable, Logger } from '@nestjs/common';
- import { ConfigService } from '@nestjs/config';
- import { ModelConfigService } from '../model-config/model-config.service';
- import { ModelType } from '../types';
- import axios from 'axios';
/**
 * A single entry of the `results` array returned by a Cohere/SiliconFlow-style
 * rerank endpoint.
 */
export interface RerankResult {
  /** Position of the scored document within the request's `documents` array. */
  index: number;
  /** Relevance score from the rerank model; higher means more relevant. */
  relevance_score: number;
  /**
   * Echo of the scored document. Only some APIs return it, so it is optional.
   * NOTE(review): several providers return an object such as `{ text: string }`
   * here instead of a plain string — confirm before relying on this field.
   */
  document?: string;
}
@Injectable()
export class RerankService {
  private readonly logger = new Logger(RerankService.name);

  constructor(
    private readonly modelConfigService: ModelConfigService,
    // NOTE(review): not referenced anywhere in this class; kept so the DI signature is unchanged.
    private readonly configService: ConfigService,
  ) { }

  /**
   * Rerank candidate documents against a query using the configured rerank model.
   *
   * Sends a Cohere/SiliconFlow/Jina-style request
   * (`POST <Base URL> { model, query, documents, top_n, return_documents }`)
   * to the model's Base URL exactly as configured. Only trailing slashes are
   * stripped; no `/rerank` suffix is appended.
   *
   * This method does not reject. On an invalid model config, a missing Base URL,
   * an unexpected response shape, or a request failure, it logs the problem and
   * returns a fallback ranking that keeps the original document order.
   *
   * @param query The user's query.
   * @param documents Candidate documents; returned indices refer to this array.
   * @param userId User ID used to look up the model config.
   * @param rerankModelId ID of the selected rerank model config.
   * @param topN Maximum number of results to return (top N). Values larger than
   *   `documents.length` are clamped. Non-positive or non-finite values mean
   *   "all documents".
   * @param tenantId Tenant for the config lookup; `'default'` when omitted or empty.
   * @returns At most `topN` `{ index, score }` pairs, sorted by descending score
   *   on success.
   */
  async rerank(
    query: string,
    documents: string[],
    userId: string,
    rerankModelId: string,
    topN: number = 5,
    tenantId?: string,
  ): Promise<{ index: number; score: number }[]> {
    if (!documents || documents.length === 0) {
      return [];
    }
    const limit = this.resolveLimit(topN, documents.length);
    // Declared outside `try` so the 404 diagnostic can report which URL was called.
    let endpoint = '';
    try {
      // 1. Load the model config.
      const modelConfig = await this.modelConfigService.findOne(rerankModelId, userId, tenantId || 'default');
      if (!modelConfig || modelConfig.type !== ModelType.RERANK) {
        this.logger.warn(`Invalid rerank model config: ${rerankModelId}`);
        // Fallback: original order with strictly decreasing dummy scores.
        return this.fallbackRanking(documents.length, limit, (index) => 0.99 - index * 0.01);
      }
      const apiKey = modelConfig.apiKey;
      // Used verbatim as the endpoint, e.g. https://api.siliconflow.cn/v1/rerank
      const baseUrl = modelConfig.baseUrl || '';
      const modelName = modelConfig.modelId; // e.g. BAAI/bge-reranker-v2-m3
      endpoint = baseUrl.replace(/\/+$/, '');
      if (!endpoint) {
        // An empty URL is never a valid rerank endpoint; fail fast rather than
        // letting axios try to resolve it with the API key attached.
        this.logger.warn(`Rerank model ${rerankModelId} has no Base URL configured`);
        return this.fallbackRanking(documents.length, limit, () => 0);
      }
      this.logger.log(`Reranking ${documents.length} docs with model ${modelName} at ${baseUrl}`);
      this.logger.log(`Calling Rerank API: ${endpoint} (Model: ${modelName})`);

      // 2. Send the request (Cohere/SiliconFlow/Jina-compatible rerank payload).
      const response = await axios.post(
        endpoint,
        {
          model: modelName,
          query,
          documents,
          top_n: limit,
          return_documents: false, // only indices and scores are needed
        },
        {
          headers: {
            'Authorization': `Bearer ${apiKey}`,
            'Content-Type': 'application/json',
          },
          timeout: 10000,
        },
      );

      // 3. Parse the response. Expected shape:
      //    { results: [ { index: 0, relevance_score: 0.98 }, ... ] }
      const ranked = this.parseRerankResults(response.data, documents.length);
      if (!ranked) {
        this.logger.error('Unexpected rerank API response structure', response.data);
        return this.fallbackRanking(documents.length, limit, () => 0);
      }
      // Truncate in case the provider ignores `top_n`.
      return ranked.slice(0, limit);
    } catch (error: unknown) {
      const { message, responseData } = this.describeRerankError(error, endpoint);
      this.logger.error(`Rerank failed: ${message}`, responseData);
      // Fallback on error: original order with zero scores.
      return this.fallbackRanking(documents.length, limit, () => 0);
    }
  }

  /**
   * Clamps `topN` to `[1, documentCount]`.
   * Non-positive or non-finite values select all documents.
   */
  private resolveLimit(topN: number, documentCount: number): number {
    if (!Number.isFinite(topN) || topN <= 0) {
      return documentCount;
    }
    return Math.min(Math.max(1, Math.floor(topN)), documentCount);
  }

  /**
   * Builds a fallback ranking that keeps the original document order,
   * truncated to `limit`, with each score produced by `scoreAt(index)`.
   */
  private fallbackRanking(
    documentCount: number,
    limit: number,
    scoreAt: (index: number) => number,
  ): { index: number; score: number }[] {
    return Array.from({ length: Math.min(documentCount, limit) }, (_, index) => ({
      index,
      score: scoreAt(index),
    }));
  }

  /**
   * Converts a `{ results: RerankResult[] }` payload into `{ index, score }`
   * pairs sorted by descending score.
   *
   * Entries with an out-of-range index or a non-finite score are dropped.
   * Returns `null` when `results` is not an array, or when it is non-empty but
   * contains no usable entries.
   */
  private parseRerankResults(
    data: unknown,
    documentCount: number,
  ): { index: number; score: number }[] | null {
    const results = (data as { results?: unknown } | null | undefined)?.results;
    if (!Array.isArray(results)) {
      return null;
    }
    const ranked = (results as RerankResult[])
      .filter(
        (r) =>
          r != null &&
          Number.isInteger(r.index) &&
          r.index >= 0 &&
          r.index < documentCount &&
          typeof r.relevance_score === 'number' &&
          Number.isFinite(r.relevance_score),
      )
      .map((r) => ({ index: r.index, score: r.relevance_score }))
      .sort((a, b) => b.score - a.score);
    if (ranked.length === 0 && results.length > 0) {
      return null;
    }
    return ranked;
  }

  /**
   * Produces a log-friendly description of a rerank failure.
   *
   * Adds hints for two common misconfigurations:
   * - a TLS protocol mismatch (HTTPS used against an HTTP server);
   * - a 404 from the configured endpoint.
   *
   * Never throws, whatever value was thrown.
   */
  private describeRerankError(
    error: unknown,
    endpoint: string,
  ): { message: string; responseData: unknown } {
    const err = (typeof error === 'object' && error !== null ? error : {}) as {
      message?: unknown;
      code?: unknown;
      response?: { status?: unknown; data?: unknown };
    };
    const baseMessage = typeof err.message === 'string' ? err.message : String(error);
    let message = baseMessage;
    if (err.code === 'EPROTO' || baseMessage.includes('wrong version number')) {
      message = `${baseMessage}. This often happens when using HTTPS to connect to an HTTP server. Please check your model Base URL protocol (http vs https).`;
    } else if (err.response?.status === 404) {
      message = `Endpoint not found (404). Tried: ${endpoint}. Please check if your Base URL is correct. (Note: We use the Base URL exactly as provided for Rerank models).`;
    }
    return { message, responseData: err.response?.data };
  }
}
|