// rerank.service.ts (4.7 KB)
  1. import { Injectable, Logger } from '@nestjs/common';
  2. import { ConfigService } from '@nestjs/config';
  3. import { ModelConfigService } from '../model-config/model-config.service';
  4. import { ModelType } from '../types';
  5. import axios from 'axios';
  /**
   * One entry of the `results` array in a rerank API response.
   *
   * Field names follow the wire format (snake_case) so a response body can be
   * read without remapping.
   */
  export interface RerankResult {
    /** Index of the scored document; expected to refer to the `documents` array sent in the request. */
    index: number;
    /** Relevance score assigned by the rerank model; consumers sort descending on it. */
    relevance_score: number;
    /** Document text. Optional, because only some APIs return it. */
    document?: string;
  }
  /**
   * Re-orders candidate documents by relevance to a query by calling an external
   * rerank HTTP endpoint described by a stored model configuration.
   *
   * The service is deliberately best-effort: every failure path (bad config,
   * network error, malformed response) is logged and answered with a fallback
   * ranking in the original document order instead of an exception.
   */
  @Injectable()
  export class RerankService {
    /** Maximum time to wait for the rerank endpoint, in milliseconds. */
    private static readonly REQUEST_TIMEOUT_MS = 10000;

    private readonly logger = new Logger(RerankService.name);

    constructor(
      private readonly modelConfigService: ModelConfigService,
      // Not used by any method visible in this class; kept so the injected
      // constructor signature stays unchanged.
      private readonly configService: ConfigService,
    ) {}

    /**
     * Runs a rerank request and returns document indices with their scores,
     * sorted by descending score.
     *
     * @param query User query the documents are scored against.
     * @param documents Candidate documents. An empty or missing list yields `[]`.
     * @param userId User ID, forwarded to the model-config lookup.
     * @param rerankModelId ID of the selected rerank model configuration.
     * @param topN Number of results requested from the API (sent as `top_n`).
     * @param tenantId Tenant ID for the model-config lookup; `'default'` when falsy.
     * @returns `{ index, score }` pairs. When the model config is missing or is not
     *   a rerank model, every document is returned in its original order with
     *   placeholder scores descending from 0.99. On any other failure every
     *   document is returned in its original order with score 0.
     */
    async rerank(
      query: string,
      documents: string[],
      userId: string,
      rerankModelId: string,
      topN: number = 5,
      tenantId?: string,
    ): Promise<{ index: number; score: number }[]> {
      if (!documents || documents.length === 0) {
        return [];
      }

      // Declared outside the try block so the catch handler can report which
      // URL was attempted.
      let endpoint = '';

      try {
        // 1. Load the model configuration.
        const modelConfig = await this.modelConfigService.findOne(
          rerankModelId,
          userId,
          tenantId || 'default',
        );

        if (!modelConfig || modelConfig.type !== ModelType.RERANK) {
          this.logger.warn(`Invalid rerank model config: ${rerankModelId}`);
          // Fallback: keep the original order, with placeholder scores that
          // decrease by 0.01 per position.
          return documents.map((_, index) => ({ index, score: 0.99 - (index * 0.01) }));
        }

        const apiKey = modelConfig.apiKey;
        const baseUrl = modelConfig.baseUrl || ''; // e.g. https://api.siliconflow.cn/v1
        const modelName = modelConfig.modelId; // e.g. BAAI/bge-reranker-v2-m3

        this.logger.log(`Reranking ${documents.length} docs with model ${modelName} at ${baseUrl}`);

        // 2. Build the API request (OpenAI/SiliconFlow-compatible rerank API).
        // Note: the standard OpenAI API has no /rerank route, but SiliconFlow,
        // Jina and Cohere use a similar structure.
        // SiliconFlow format: POST /v1/rerank { model, query, documents, top_n }
        // The Base URL is used exactly as configured, minus trailing slashes.
        endpoint = baseUrl.replace(/\/+$/, '');
        if (!endpoint) {
          // Without a URL the request cannot succeed; skip the HTTP call.
          this.logger.error(`Rerank failed: no Base URL configured for rerank model ${rerankModelId}`);
          return this.zeroScoreFallback(documents);
        }

        // Log the exact endpoint being called.
        this.logger.log(`Calling Rerank API: ${endpoint} (Model: ${modelName})`);

        const headers: Record<string, string> = { 'Content-Type': 'application/json' };
        if (apiKey) {
          // Only send credentials when a key is configured, instead of a
          // literal "Bearer undefined".
          headers['Authorization'] = `Bearer ${apiKey}`;
        }

        const response = await axios.post(
          endpoint,
          {
            model: modelName,
            query,
            documents,
            top_n: topN,
            return_documents: false, // We only need indices and scores
          },
          { headers, timeout: RerankService.REQUEST_TIMEOUT_MS },
        );

        // 3. Parse the response.
        // Expected format (SiliconFlow/Cohere):
        // { results: [ { index: 0, relevance_score: 0.98 }, ... ] }
        const rawResults: unknown = response.data?.results;
        if (!Array.isArray(rawResults)) {
          this.logger.error('Unexpected rerank API response structure', response.data);
          return this.zeroScoreFallback(documents);
        }

        const validResults = rawResults.filter(
          (entry): entry is RerankResult => RerankService.isValidResult(entry, documents.length),
        );
        if (validResults.length === 0 && rawResults.length > 0) {
          // The array is present but no entry has a usable index and score.
          this.logger.error('Unexpected rerank API response structure', response.data);
          return this.zeroScoreFallback(documents);
        }
        if (validResults.length < rawResults.length) {
          this.logger.warn(
            `Ignoring ${rawResults.length - validResults.length} malformed rerank result(s)`,
          );
        }

        return validResults
          .map((result) => ({ index: result.index, score: result.relevance_score }))
          .sort((a, b) => b.score - a.score); // Ensure sorted
      } catch (error: unknown) {
        const responseData = axios.isAxiosError(error) ? error.response?.data : undefined;
        this.logger.error(`Rerank failed: ${this.describeError(error, endpoint)}`, responseData);
        // Fallback on error: return the original order.
        return this.zeroScoreFallback(documents);
      }
    }

    /** Builds the error fallback: every document in its original order with score 0. */
    private zeroScoreFallback(documents: string[]): { index: number; score: number }[] {
      return documents.map((_, index) => ({ index, score: 0 }));
    }

    /**
     * Checks that a response entry has an integer index that addresses one of
     * the submitted documents and a finite numeric relevance score.
     */
    private static isValidResult(entry: unknown, documentCount: number): entry is RerankResult {
      if (typeof entry !== 'object' || entry === null) {
        return false;
      }
      const { index, relevance_score: score } = entry as Partial<RerankResult>;
      return (
        typeof index === 'number' &&
        Number.isInteger(index) &&
        index >= 0 &&
        index < documentCount &&
        typeof score === 'number' &&
        Number.isFinite(score)
      );
    }

    /**
     * Turns a caught value into a log message, adding a hint for TLS/protocol
     * mismatches and for HTTP 404 responses. Accepts any thrown value, not
     * only `Error` instances.
     */
    private describeError(error: unknown, endpoint: string): string {
      const message = error instanceof Error ? error.message : String(error);
      const code =
        typeof error === 'object' && error !== null && 'code' in error
          ? (error as { code?: unknown }).code
          : undefined;

      if (code === 'EPROTO' || message.includes('wrong version number')) {
        return `${message}. This often happens when using HTTPS to connect to an HTTP server. Please check your model Base URL protocol (http vs https).`;
      }
      if (axios.isAxiosError(error) && error.response?.status === 404) {
        return `Endpoint not found (404). Tried: ${endpoint}. Please check if your Base URL is correct. (Note: We use the Base URL exactly as provided for Rerank models).`;
      }
      return message;
    }
  }