embedding.service.ts 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import { Injectable, Logger } from '@nestjs/common';
  2. import { ConfigService } from '@nestjs/config';
  3. import { ModelConfigService } from '../model-config/model-config.service';
  4. export interface EmbeddingResponse {
  5. data: Array<{
  6. embedding: number[];
  7. index: number;
  8. }>;
  9. model: string;
  10. usage: {
  11. prompt_tokens: number;
  12. total_tokens: number;
  13. };
  14. }
  15. @Injectable()
  16. export class EmbeddingService {
  17. private readonly logger = new Logger(EmbeddingService.name);
  18. private readonly defaultDimensions: number;
  19. constructor(
  20. private modelConfigService: ModelConfigService,
  21. private configService: ConfigService,
  22. ) {
  23. this.defaultDimensions = parseInt(
  24. this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
  25. );
  26. this.logger.log(`デフォルトのベクトル次元が ${this.defaultDimensions} に設定されました`);
  27. }
  28. async getEmbeddings(
  29. texts: string[],
  30. userId: string,
  31. embeddingModelConfigId: string,
  32. tenantId?: string,
  33. ): Promise<number[][]> {
  34. this.logger.log(`${texts.length} 個のテキストに対して埋め込みベクトルを生成しています`);
  35. const modelConfig = await this.modelConfigService.findOne(
  36. embeddingModelConfigId,
  37. userId,
  38. tenantId || 'default',
  39. );
  40. if (!modelConfig || modelConfig.type !== 'embedding') {
  41. throw new Error(`埋め込みモデル設定 ${embeddingModelConfigId} が見つかりません`);
  42. }
  43. if (modelConfig.isEnabled === false) {
  44. throw new Error(`モデル ${modelConfig.name} は無効化されているため、埋め込みベクトルを生成できません`);
  45. }
  46. // APIキーはオプションです - ローカルモデルを許可します
  47. if (!modelConfig.baseUrl) {
  48. throw new Error(`モデル ${modelConfig.name} に baseUrl が設定されていません`);
  49. }
  50. // モデル名に基づいて最大バッチサイズを決定
  51. const maxBatchSize = this.getMaxBatchSizeForModel(modelConfig.modelId, modelConfig.maxBatchSize);
  52. // バッチサイズが制限を超える場合は分割して処理
  53. if (texts.length > maxBatchSize) {
  54. this.logger.log(
  55. `テキスト数 ${texts.length} がモデルのバッチ制限 ${maxBatchSize} を超えているため、分割処理します`
  56. );
  57. const allEmbeddings: number[][] = [];
  58. for (let i = 0; i < texts.length; i += maxBatchSize) {
  59. const batch = texts.slice(i, i + maxBatchSize);
  60. const batchEmbeddings = await this.getEmbeddingsForBatch(
  61. batch,
  62. userId,
  63. modelConfig,
  64. maxBatchSize
  65. );
  66. allEmbeddings.push(...batchEmbeddings);
  67. // APIレート制限対策のため、短い間隔で待機
  68. if (i + maxBatchSize < texts.length) {
  69. await new Promise(resolve => setTimeout(resolve, 100)); // 100ms待機
  70. }
  71. }
  72. return allEmbeddings;
  73. } else {
  74. // 通常処理(バッチサイズ以内)
  75. return await this.getEmbeddingsForBatch(
  76. texts,
  77. userId,
  78. modelConfig,
  79. maxBatchSize
  80. );
  81. }
  82. }
  83. /**
  84. * モデルIDに基づいて最大バッチサイズを決定
  85. */
  86. private getMaxBatchSizeForModel(modelId: string, configuredMaxBatchSize?: number): number {
  87. // モデル固有のバッチサイズ制限
  88. if (modelId.includes('text-embedding-004') || modelId.includes('text-embedding-v4') ||
  89. modelId.includes('text-embedding-ada-002')) {
  90. return Math.min(10, configuredMaxBatchSize || 100); // Googleの場合は10を上限
  91. } else if (modelId.includes('text-embedding-3') || modelId.includes('text-embedding-003')) {
  92. return Math.min(2048, configuredMaxBatchSize || 2048); // OpenAI v3は2048が上限
  93. } else {
  94. // デフォルトでは設定された最大バッチサイズか100の小さい方
  95. return Math.min(configuredMaxBatchSize || 100, 100);
  96. }
  97. }
  98. /**
  99. * 単一バッチの埋め込み処理
  100. */
  101. private async getEmbeddingsForBatch(
  102. texts: string[],
  103. userId: string,
  104. modelConfig: any,
  105. maxBatchSize: number,
  106. ): Promise<number[][]> {
  107. const apiUrl = modelConfig.baseUrl.endsWith('/embeddings')
  108. ? modelConfig.baseUrl
  109. : `${modelConfig.baseUrl}/embeddings`;
  110. let lastError;
  111. const MAX_RETRIES = 3;
  112. for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
  113. try {
  114. const controller = new AbortController();
  115. const timeoutId = setTimeout(() => {
  116. controller.abort();
  117. this.logger.error(`Embedding API timeout after 60s: ${apiUrl}`);
  118. }, 60000); // 60s timeout
  119. this.logger.log(`[モデル呼び出し] タイプ: Embedding, モデル: ${modelConfig.name} (${modelConfig.modelId}), ユーザー: ${userId}, テキスト数: ${texts.length}`);
  120. this.logger.log(`埋め込み API を呼び出し中 (試行 ${attempt}/${MAX_RETRIES}): ${apiUrl}`);
  121. let response;
  122. try {
  123. response = await fetch(apiUrl, {
  124. method: 'POST',
  125. headers: {
  126. 'Content-Type': 'application/json',
  127. Authorization: `Bearer ${modelConfig.apiKey}`,
  128. },
  129. body: JSON.stringify({
  130. encoding_format: 'float',
  131. input: texts,
  132. model: modelConfig.modelId,
  133. }),
  134. signal: controller.signal,
  135. });
  136. } finally {
  137. clearTimeout(timeoutId);
  138. }
  139. if (!response.ok) {
  140. const errorText = await response.text();
  141. // バッチサイズ制限エラーを検出
  142. if (errorText.includes('batch size is invalid') || errorText.includes('batch_size') ||
  143. errorText.includes('invalid') || errorText.includes('larger than')) {
  144. this.logger.warn(
  145. `バッチサイズ制限エラーが検出されました。バッチサイズを半分に分割して再試行します: ${maxBatchSize} -> ${Math.floor(maxBatchSize / 2)}`
  146. );
  147. // バッチをさらに小さな単位に分割して再試行
  148. if (texts.length > 1) {
  149. const midPoint = Math.floor(texts.length / 2);
  150. const firstHalf = texts.slice(0, midPoint);
  151. const secondHalf = texts.slice(midPoint);
  152. const firstResult = await this.getEmbeddingsForBatch(firstHalf, userId, modelConfig, Math.floor(maxBatchSize / 2));
  153. const secondResult = await this.getEmbeddingsForBatch(secondHalf, userId, modelConfig, Math.floor(maxBatchSize / 2));
  154. return [...firstResult, ...secondResult];
  155. }
  156. }
  157. // コンテキスト長の過剰エラーを検出
  158. if (errorText.includes('context length') || errorText.includes('exceeds')) {
  159. const avgLength = texts.reduce((s, t) => s + t.length, 0) / texts.length;
  160. const totalLength = texts.reduce((s, t) => s + t.length, 0);
  161. this.logger.error(
  162. `テキスト長が制限を超過しました: 入力 ${texts.length} 個のテキスト、` +
  163. `総計 ${totalLength} 文字、平均 ${Math.round(avgLength)} 文字、` +
  164. `モデル制限: ${modelConfig.maxInputTokens || 8192} tokens`
  165. );
  166. throw new Error(
  167. `テキスト長がモデルの制限を超えています。` +
  168. `現在: ${texts.length} 個のテキストで計 ${totalLength} 文字、` +
  169. `モデル制限: ${modelConfig.maxInputTokens || 8192} tokens。` +
  170. `アドバイス: チャンクサイズまたはバッチサイズを小さくしてください`
  171. );
  172. }
  173. // 429 (Too Many Requests) または 5xx (Server Error) の場合は再試行
  174. if (response.status === 429 || response.status >= 500) {
  175. this.logger.warn(`埋め込み API で一時的なエラーが発生しました (${response.status}): ${errorText}`);
  176. throw new Error(`API Error ${response.status}: ${errorText}`);
  177. }
  178. this.logger.error(`埋め込み API エラーの詳細: ${errorText}`);
  179. this.logger.error(`リクエストパラメータ: model=${modelConfig.modelId}, inputLength=${texts[0]?.length}`);
  180. throw new Error(`埋め込み API の呼び出しに失敗しました: ${response.statusText} - ${errorText}`);
  181. }
  182. const data: EmbeddingResponse = await response.json();
  183. const embeddings = data.data.map((item) => item.embedding);
  184. // 実際のレスポンスから次元を取得
  185. const actualDimensions = embeddings[0]?.length || this.defaultDimensions;
  186. this.logger.log(
  187. `${modelConfig.name} から ${embeddings.length} 個の埋め込みベクトルを取得しました。次元: ${actualDimensions}`,
  188. );
  189. return embeddings;
  190. } catch (error) {
  191. lastError = error;
  192. // 最後のアテンプトでなく、エラーが一時的と思われる場合(または堅牢性のために全て)は、待機後に再試行
  193. if (attempt < MAX_RETRIES) {
  194. const delay = Math.pow(2, attempt - 1) * 1000; // 1s, 2s, 4s
  195. this.logger.warn(`埋め込みリクエストが失敗しました。${delay}ms 後に再試行します: ${error.message}`);
  196. await new Promise(resolve => setTimeout(resolve, delay));
  197. continue;
  198. }
  199. }
  200. }
  201. throw lastError;
  202. }
  203. private getEstimatedDimensions(modelId: string): number {
  204. // 使用环境变量的默认维度
  205. return this.defaultDimensions;
  206. }
  207. }