- import { Injectable, Logger } from '@nestjs/common';
- import { ConfigService } from '@nestjs/config';
- import { ModelConfigService } from '../model-config/model-config.service';
- import { I18nService } from '../i18n/i18n.service';
/**
 * Response payload returned by an OpenAI-compatible `/embeddings` endpoint.
 * Property names use snake_case because they mirror the wire format.
 */
export interface EmbeddingResponse {
  /** One entry per input text; `index` maps the vector back to request order. */
  data: Array<{
    embedding: number[];
    index: number;
  }>;
  /** Identifier of the model that produced the vectors. */
  model: string;
  /** Token accounting as reported by the provider. */
  usage: {
    prompt_tokens: number;
    total_tokens: number;
  };
}
- @Injectable()
- export class EmbeddingService {
- private readonly logger = new Logger(EmbeddingService.name);
- private readonly defaultDimensions: number;
- constructor(
- private modelConfigService: ModelConfigService,
- private configService: ConfigService,
- private i18nService: I18nService,
- ) {
- this.defaultDimensions = parseInt(
- this.configService.get<string>('DEFAULT_VECTOR_DIMENSIONS', '2560'),
- );
- this.logger.log(`Default vector dimensions set to ${this.defaultDimensions}`);
- }
- async getEmbeddings(
- texts: string[],
- userId: string,
- embeddingModelConfigId: string,
- tenantId?: string,
- ): Promise<number[][]> {
- this.logger.log(`Generating embeddings for ${texts.length} texts`);
- const modelConfig = await this.modelConfigService.findOne(
- embeddingModelConfigId,
- userId,
- tenantId || 'default',
- );
- if (!modelConfig || modelConfig.type !== 'embedding') {
- throw new Error(this.i18nService.formatMessage('embeddingModelNotFound', { id: embeddingModelConfigId }));
- }
- if (modelConfig.isEnabled === false) {
- throw new Error(`Model ${modelConfig.name} is disabled and cannot generate embeddings`);
- }
- // API key is optional - allows local models
- if (!modelConfig.baseUrl) {
- throw new Error(`Model ${modelConfig.name} does not have baseUrl configured`);
- }
- // Determine max batch size based on model name
- const maxBatchSize = this.getMaxBatchSizeForModel(modelConfig.modelId, modelConfig.maxBatchSize);
- // Split processing if batch size exceeds limit
- if (texts.length > maxBatchSize) {
- this.logger.log(
- `Splitting ${texts.length} texts into batches (model batch limit: ${maxBatchSize})`
- );
- const allEmbeddings: number[][] = [];
- for (let i = 0; i < texts.length; i += maxBatchSize) {
- const batch = texts.slice(i, i + maxBatchSize);
- const batchEmbeddings = await this.getEmbeddingsForBatch(
- batch,
- userId,
- modelConfig,
- maxBatchSize
- );
- allEmbeddings.push(...batchEmbeddings);
- // Wait briefly to avoid API rate limiting
- if (i + maxBatchSize < texts.length) {
- await new Promise(resolve => setTimeout(resolve, 100)); // Wait 100ms
- }
- }
- return allEmbeddings;
- } else {
- // Normal processing (within batch size)
- return await this.getEmbeddingsForBatch(
- texts,
- userId,
- modelConfig,
- maxBatchSize
- );
- }
- }
- /**
- * Determine max batch size based on model ID
- */
- private getMaxBatchSizeForModel(modelId: string, configuredMaxBatchSize?: number): number {
- // Model-specific batch size limits
- if (modelId.includes('text-embedding-004') || modelId.includes('text-embedding-v4') ||
- modelId.includes('text-embedding-ada-002')) {
- return Math.min(10, configuredMaxBatchSize || 100); // Google limit: 10
- } else if (modelId.includes('text-embedding-3') || modelId.includes('text-embedding-003')) {
- return Math.min(2048, configuredMaxBatchSize || 2048); // OpenAI v3 limit: 2048
- } else {
- // Default: smaller of configured max or 100
- return Math.min(configuredMaxBatchSize || 100, 100);
- }
- }
- /**
- * Process single batch embedding
- */
- private async getEmbeddingsForBatch(
- texts: string[],
- userId: string,
- modelConfig: any,
- maxBatchSize: number,
- ): Promise<number[][]> {
- const apiUrl = modelConfig.baseUrl.endsWith('/embeddings')
- ? modelConfig.baseUrl
- : `${modelConfig.baseUrl}/embeddings`;
- let lastError;
- const MAX_RETRIES = 3;
- for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
- try {
- const controller = new AbortController();
- const timeoutId = setTimeout(() => {
- controller.abort();
- this.logger.error(`Embedding API timeout after 60s: ${apiUrl}`);
- }, 60000); // 60s timeout
- this.logger.log(`[Model call] Type: Embedding, Model: ${modelConfig.name} (${modelConfig.modelId}), User: ${userId}, Text count: ${texts.length}`);
- this.logger.log(`Calling embedding API (attempt ${attempt}/${MAX_RETRIES}): ${apiUrl}`);
- let response;
- try {
- response = await fetch(apiUrl, {
- method: 'POST',
- headers: {
- 'Content-Type': 'application/json',
- Authorization: `Bearer ${modelConfig.apiKey}`,
- },
- body: JSON.stringify({
- encoding_format: 'float',
- input: texts,
- model: modelConfig.modelId,
- }),
- signal: controller.signal,
- });
- } finally {
- clearTimeout(timeoutId);
- }
- if (!response.ok) {
- const errorText = await response.text();
- // Detect batch size limit error
- if (errorText.includes('batch size is invalid') || errorText.includes('batch_size') ||
- errorText.includes('invalid') || errorText.includes('larger than')) {
- this.logger.warn(
- `Batch size limit error detected. Splitting batch in half and retrying: ${maxBatchSize} -> ${Math.floor(maxBatchSize / 2)}`
- );
- // Split batch into smaller units and retry
- if (texts.length > 1) {
- const midPoint = Math.floor(texts.length / 2);
- const firstHalf = texts.slice(0, midPoint);
- const secondHalf = texts.slice(midPoint);
- const firstResult = await this.getEmbeddingsForBatch(firstHalf, userId, modelConfig, Math.floor(maxBatchSize / 2));
- const secondResult = await this.getEmbeddingsForBatch(secondHalf, userId, modelConfig, Math.floor(maxBatchSize / 2));
- return [...firstResult, ...secondResult];
- }
- }
- // Detect context length excess error
- if (errorText.includes('context length') || errorText.includes('exceeds')) {
- const avgLength = texts.reduce((s, t) => s + t.length, 0) / texts.length;
- const totalLength = texts.reduce((s, t) => s + t.length, 0);
- this.logger.error(
- `Text length exceeds limit: ${texts.length} texts, ` +
- `total ${totalLength} characters, average ${Math.round(avgLength)} characters, ` +
- `model limit: ${modelConfig.maxInputTokens || 8192} tokens`
- );
- throw new Error(
- `Text length exceeds model limit. ` +
- `Current: ${texts.length} texts with total ${totalLength} characters, ` +
- `model limit: ${modelConfig.maxInputTokens || 8192} tokens. ` +
- `Advice: Reduce chunk size or batch size`
- );
- }
- // Retry on 429 (Too Many Requests) or 5xx (Server Error)
- if (response.status === 429 || response.status >= 500) {
- this.logger.warn(`Temporary error from embedding API (${response.status}): ${errorText}`);
- throw new Error(`API Error ${response.status}: ${errorText}`);
- }
- this.logger.error(`Embedding API error details: ${errorText}`);
- this.logger.error(`Request parameters: model=${modelConfig.modelId}, inputLength=${texts[0]?.length}`);
- throw new Error(`Embedding API call failed: ${response.statusText} - ${errorText}`);
- }
- const data: EmbeddingResponse = await response.json();
- const embeddings = data.data.map((item) => item.embedding);
- // Get dimensions from actual response
- const actualDimensions = embeddings[0]?.length || this.defaultDimensions;
- this.logger.log(
- `Got ${embeddings.length} embedding vectors from ${modelConfig.name}. Dimensions: ${actualDimensions}`,
- );
- return embeddings;
- } catch (error) {
- lastError = error;
- // If not the last attempt and error appears temporary (or for robustness on all), retry after waiting
- if (attempt < MAX_RETRIES) {
- const delay = Math.pow(2, attempt - 1) * 1000; // 1s, 2s, 4s
- this.logger.warn(`Embedding request failed. Retrying after ${delay}ms: ${error.message}`);
- await new Promise(resolve => setTimeout(resolve, delay));
- continue;
- }
- }
- }
- throw lastError;
- }
- private getEstimatedDimensions(modelId: string): number {
- // Use default dimensions from environment variable
- return this.defaultDimensions;
- }
- }