Add embedding support to Mistral client and mock

This commit is contained in:
Achim Rohn
2026-01-09 17:39:52 +01:00
parent 7317f32fe4
commit 58f5dd48d9
+92
View File
@@ -91,6 +91,35 @@ type UsageInfo struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
} }
// EmbeddingRequest represents the request body for embeddings
type EmbeddingRequest struct {
Model string `json:"model"`
Input []string `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
}
// EmbeddingResponse represents the API response for embeddings
type EmbeddingResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Data []EmbeddingData `json:"data"`
Model string `json:"model"`
Usage EmbeddingUsageInfo `json:"usage"`
}
// EmbeddingData represents a single embedding result
type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
// EmbeddingUsageInfo represents token usage information for embeddings
type EmbeddingUsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}
// Client represents a Mistral API client // Client represents a Mistral API client
type MistralClientImpl struct { type MistralClientImpl struct {
APIKey string APIKey string
@@ -145,8 +174,47 @@ func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*C
return &result, nil return &result, nil
} }
// CreateEmbedding sends an embedding request to the API
func (c *MistralClientImpl) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
jsonData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %v", err)
}
request, err := http.NewRequest("POST", c.BaseURL+"/v1/embeddings", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Authorization", "Bearer "+c.APIKey)
response, err := c.HTTPClient.Do(request)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %v", err)
}
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", response.StatusCode, string(body))
}
var result EmbeddingResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("error unmarshaling response: %v", err)
}
return &result, nil
}
type MistralClient interface { type MistralClient interface {
CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error)
CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error)
} }
type MistralClientMock struct { type MistralClientMock struct {
@@ -170,3 +238,27 @@ func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*Ch
return &response, nil return &response, nil
//return nil, errors.New("something went wrong") //return nil, errors.New("something went wrong")
} }
func (m MistralClientMock) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
ersteller.Debug("mocking the mistral client embedding")
// Return mock embeddings with dummy values
data := make([]EmbeddingData, len(req.Input))
for i := range req.Input {
data[i] = EmbeddingData{
Object: "embedding",
Embedding: make([]float64, 1024), // Mistral embedding models typically use 1024 dimensions
Index: i,
}
}
response := EmbeddingResponse{
ID: "mock-embedding-id",
Object: "list",
Data: data,
Model: req.Model,
Usage: EmbeddingUsageInfo{
PromptTokens: len(req.Input) * 10,
TotalTokens: len(req.Input) * 10,
},
}
return &response, nil
}