From 58f5dd48d9c64782e25d49f234e7076d39b0bd12 Mon Sep 17 00:00:00 2001 From: Achim Rohn Date: Fri, 9 Jan 2026 17:39:52 +0100 Subject: [PATCH] Add embedding support to Mistral client and mock --- llm/mistral.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/llm/mistral.go b/llm/mistral.go index 737f83e..b30bbb9 100644 --- a/llm/mistral.go +++ b/llm/mistral.go @@ -91,6 +91,35 @@ type UsageInfo struct { TotalTokens int `json:"total_tokens"` } +// EmbeddingRequest represents the request body for embeddings +type EmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` + EncodingFormat string `json:"encoding_format,omitempty"` +} + +// EmbeddingResponse represents the API response for embeddings +type EmbeddingResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Data []EmbeddingData `json:"data"` + Model string `json:"model"` + Usage EmbeddingUsageInfo `json:"usage"` +} + +// EmbeddingData represents a single embedding result +type EmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingUsageInfo represents token usage information for embeddings +type EmbeddingUsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + // Client represents a Mistral API client type MistralClientImpl struct { APIKey string @@ -145,8 +174,47 @@ func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*C return &result, nil } +// CreateEmbedding sends an embedding request to the API +func (c *MistralClientImpl) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) { + jsonData, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + request, err := http.NewRequest("POST", c.BaseURL+"/v1/embeddings", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Authorization", "Bearer "+c.APIKey) + + response, err := c.HTTPClient.Do(request) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", response.StatusCode, string(body)) + } + + var result EmbeddingResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return &result, nil +} + type MistralClient interface { CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) + CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) } type MistralClientMock struct { @@ -170,3 +238,27 @@ func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*Ch return &response, nil //return nil, errors.New("something went wrong") } + +func (m MistralClientMock) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) { + ersteller.Debug("mocking the mistral client embedding") + // Return mock embeddings with dummy values + data := make([]EmbeddingData, len(req.Input)) + for i := range req.Input { + data[i] = EmbeddingData{ + Object: "embedding", + Embedding: make([]float64, 1024), // Mistral embedding models typically use 1024 dimensions + Index: i, + } + } + response := EmbeddingResponse{ + ID: "mock-embedding-id", + Object: "list", + Data: data, + Model: req.Model, + Usage: EmbeddingUsageInfo{ + PromptTokens: len(req.Input) * 10, + TotalTokens: len(req.Input) * 10, + }, + } + return &response, nil +}