// embedding.service.ts
  1. import { Injectable, Logger } from '@nestjs/common';
  2. import { ConfigService } from '@nestjs/config';
  3. import { ModelConfigService } from '../model-config/model-config.service';
  4. import { I18nService } from '../i18n/i18n.service';
/**
 * Shape of the JSON body returned by an OpenAI-compatible `/embeddings`
 * endpoint (the request below sends `{ encoding_format, input, model }`).
 */
export interface EmbeddingResponse {
  /** One entry per input text; `index` presumably maps back to the input position — TODO confirm against provider docs. */
  data: Array<{
    embedding: number[];
    index: number;
  }>;
  /** Model identifier echoed back by the API. */
  model: string;
  /** Token accounting reported by the API. */
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
  16. @Injectable()
  17. export class EmbeddingService {
  18. private readonly logger = new Logger(EmbeddingService.name);
  19. private readonly defaultDimensions: number;
  20. constructor(
  21. private modelConfigService: ModelConfigService,
  22. private configService: ConfigService,
  23. private i18nService: I18nService,
  24. ) {
  25. this.defaultDimensions = parseInt(
  26. this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
  27. );
  28. this.logger.log(`Default vector dimensions set to ${this.defaultDimensions}`);
  29. }
  30. async getEmbeddings(
  31. texts: string[],
  32. userId: string,
  33. embeddingModelConfigId: string,
  34. tenantId?: string,
  35. ): Promise<number[][]> {
  36. this.logger.log(`Generating embeddings for ${texts.length} texts`);
  37. const modelConfig = await this.modelConfigService.findOne(
  38. embeddingModelConfigId,
  39. userId,
  40. tenantId || 'default',
  41. );
  42. if (!modelConfig || modelConfig.type !== 'embedding') {
  43. throw new Error(this.i18nService.formatMessage('embeddingModelNotFound', { id: embeddingModelConfigId }));
  44. }
  45. if (modelConfig.isEnabled === false) {
  46. throw new Error(`Model ${modelConfig.name} is disabled and cannot generate embeddings`);
  47. }
  48. // API key is optional - allows local models
  49. if (!modelConfig.baseUrl) {
  50. throw new Error(`Model ${modelConfig.name} does not have baseUrl configured`);
  51. }
  52. // Determine max batch size based on model name
  53. const maxBatchSize = this.getMaxBatchSizeForModel(modelConfig.modelId, modelConfig.maxBatchSize);
  54. // Split processing if batch size exceeds limit
  55. if (texts.length > maxBatchSize) {
  56. this.logger.log(
  57. `Splitting ${texts.length} texts into batches (model batch limit: ${maxBatchSize})`
  58. );
  59. const allEmbeddings: number[][] = [];
  60. for (let i = 0; i < texts.length; i += maxBatchSize) {
  61. const batch = texts.slice(i, i + maxBatchSize);
  62. const batchEmbeddings = await this.getEmbeddingsForBatch(
  63. batch,
  64. userId,
  65. modelConfig,
  66. maxBatchSize
  67. );
  68. allEmbeddings.push(...batchEmbeddings);
  69. // Wait briefly to avoid API rate limiting
  70. if (i + maxBatchSize < texts.length) {
  71. await new Promise(resolve => setTimeout(resolve, 100)); // Wait 100ms
  72. }
  73. }
  74. return allEmbeddings;
  75. } else {
  76. // Normal processing (within batch size)
  77. return await this.getEmbeddingsForBatch(
  78. texts,
  79. userId,
  80. modelConfig,
  81. maxBatchSize
  82. );
  83. }
  84. }
  85. /**
  86. * Determine max batch size based on model ID
  87. */
  88. private getMaxBatchSizeForModel(modelId: string, configuredMaxBatchSize?: number): number {
  89. // Model-specific batch size limits
  90. if (modelId.includes('text-embedding-004') || modelId.includes('text-embedding-v4') ||
  91. modelId.includes('text-embedding-ada-002')) {
  92. return Math.min(10, configuredMaxBatchSize || 100); // Google limit: 10
  93. } else if (modelId.includes('text-embedding-3') || modelId.includes('text-embedding-003')) {
  94. return Math.min(2048, configuredMaxBatchSize || 2048); // OpenAI v3 limit: 2048
  95. } else {
  96. // Default: smaller of configured max or 100
  97. return Math.min(configuredMaxBatchSize || 100, 100);
  98. }
  99. }
  100. /**
  101. * Process single batch embedding
  102. */
  103. private async getEmbeddingsForBatch(
  104. texts: string[],
  105. userId: string,
  106. modelConfig: any,
  107. maxBatchSize: number,
  108. ): Promise<number[][]> {
  109. const apiUrl = modelConfig.baseUrl.endsWith('/embeddings')
  110. ? modelConfig.baseUrl
  111. : `${modelConfig.baseUrl}/embeddings`;
  112. let lastError;
  113. const MAX_RETRIES = 3;
  114. for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
  115. try {
  116. const controller = new AbortController();
  117. const timeoutId = setTimeout(() => {
  118. controller.abort();
  119. this.logger.error(`Embedding API timeout after 60s: ${apiUrl}`);
  120. }, 60000); // 60s timeout
  121. this.logger.log(`[Model call] Type: Embedding, Model: ${modelConfig.name} (${modelConfig.modelId}), User: ${userId}, Text count: ${texts.length}`);
  122. this.logger.log(`Calling embedding API (attempt ${attempt}/${MAX_RETRIES}): ${apiUrl}`);
  123. let response;
  124. try {
  125. response = await fetch(apiUrl, {
  126. method: 'POST',
  127. headers: {
  128. 'Content-Type': 'application/json',
  129. Authorization: `Bearer ${modelConfig.apiKey}`,
  130. },
  131. body: JSON.stringify({
  132. encoding_format: 'float',
  133. input: texts,
  134. model: modelConfig.modelId,
  135. }),
  136. signal: controller.signal,
  137. });
  138. } finally {
  139. clearTimeout(timeoutId);
  140. }
  141. if (!response.ok) {
  142. const errorText = await response.text();
  143. // Detect batch size limit error
  144. if (errorText.includes('batch size is invalid') || errorText.includes('batch_size') ||
  145. errorText.includes('invalid') || errorText.includes('larger than')) {
  146. this.logger.warn(
  147. `Batch size limit error detected. Splitting batch in half and retrying: ${maxBatchSize} -> ${Math.floor(maxBatchSize / 2)}`
  148. );
  149. // Split batch into smaller units and retry
  150. if (texts.length > 1) {
  151. const midPoint = Math.floor(texts.length / 2);
  152. const firstHalf = texts.slice(0, midPoint);
  153. const secondHalf = texts.slice(midPoint);
  154. const firstResult = await this.getEmbeddingsForBatch(firstHalf, userId, modelConfig, Math.floor(maxBatchSize / 2));
  155. const secondResult = await this.getEmbeddingsForBatch(secondHalf, userId, modelConfig, Math.floor(maxBatchSize / 2));
  156. return [...firstResult, ...secondResult];
  157. }
  158. }
  159. // Detect context length excess error
  160. if (errorText.includes('context length') || errorText.includes('exceeds')) {
  161. const avgLength = texts.reduce((s, t) => s + t.length, 0) / texts.length;
  162. const totalLength = texts.reduce((s, t) => s + t.length, 0);
  163. this.logger.error(
  164. `Text length exceeds limit: ${texts.length} texts, ` +
  165. `total ${totalLength} characters, average ${Math.round(avgLength)} characters, ` +
  166. `model limit: ${modelConfig.maxInputTokens || 8192} tokens`
  167. );
  168. throw new Error(
  169. `Text length exceeds model limit. ` +
  170. `Current: ${texts.length} texts with total ${totalLength} characters, ` +
  171. `model limit: ${modelConfig.maxInputTokens || 8192} tokens. ` +
  172. `Advice: Reduce chunk size or batch size`
  173. );
  174. }
  175. // Retry on 429 (Too Many Requests) or 5xx (Server Error)
  176. if (response.status === 429 || response.status >= 500) {
  177. this.logger.warn(`Temporary error from embedding API (${response.status}): ${errorText}`);
  178. throw new Error(`API Error ${response.status}: ${errorText}`);
  179. }
  180. this.logger.error(`Embedding API error details: ${errorText}`);
  181. this.logger.error(`Request parameters: model=${modelConfig.modelId}, inputLength=${texts[0]?.length}`);
  182. throw new Error(`Embedding API call failed: ${response.statusText} - ${errorText}`);
  183. }
  184. const data: EmbeddingResponse = await response.json();
  185. const embeddings = data.data.map((item) => item.embedding);
  186. // Get dimensions from actual response
  187. const actualDimensions = embeddings[0]?.length || this.defaultDimensions;
  188. this.logger.log(
  189. `Got ${embeddings.length} embedding vectors from ${modelConfig.name}. Dimensions: ${actualDimensions}`,
  190. );
  191. return embeddings;
  192. } catch (error) {
  193. lastError = error;
  194. // If not the last attempt and error appears temporary (or for robustness on all), retry after waiting
  195. if (attempt < MAX_RETRIES) {
  196. const delay = Math.pow(2, attempt - 1) * 1000; // 1s, 2s, 4s
  197. this.logger.warn(`Embedding request failed. Retrying after ${delay}ms: ${error.message}`);
  198. await new Promise(resolve => setTimeout(resolve, delay));
  199. continue;
  200. }
  201. }
  202. }
  203. throw lastError;
  204. }
  205. private getEstimatedDimensions(modelId: string): number {
  206. // Use default dimensions from environment variable
  207. return this.defaultDimensions;
  208. }
  209. }