Add embedding support to Mistral client and mock
This commit is contained in:
@@ -91,6 +91,35 @@ type UsageInfo struct {
|
|||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbeddingRequest represents the request body for embeddings
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input []string `json:"input"`
|
||||||
|
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingResponse represents the API response for embeddings
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []EmbeddingData `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage EmbeddingUsageInfo `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingData represents a single embedding result
|
||||||
|
type EmbeddingData struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingUsageInfo represents token usage information for embeddings
|
||||||
|
type EmbeddingUsageInfo struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
// Client represents a Mistral API client
|
// Client represents a Mistral API client
|
||||||
type MistralClientImpl struct {
|
type MistralClientImpl struct {
|
||||||
APIKey string
|
APIKey string
|
||||||
@@ -145,8 +174,47 @@ func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*C
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateEmbedding sends an embedding request to the API
|
||||||
|
func (c *MistralClientImpl) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshaling request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
request, err := http.NewRequest("POST", c.BaseURL+"/v1/embeddings", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header.Set("Content-Type", "application/json")
|
||||||
|
request.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
|
response, err := c.HTTPClient.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error making request: %v", err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error reading response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("API request failed with status %d: %s", response.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result EmbeddingResponse
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshaling response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
type MistralClient interface {
|
type MistralClient interface {
|
||||||
CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error)
|
CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error)
|
||||||
|
CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MistralClientMock struct {
|
type MistralClientMock struct {
|
||||||
@@ -170,3 +238,27 @@ func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*Ch
|
|||||||
return &response, nil
|
return &response, nil
|
||||||
//return nil, errors.New("something went wrong")
|
//return nil, errors.New("something went wrong")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m MistralClientMock) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||||
|
ersteller.Debug("mocking the mistral client embedding")
|
||||||
|
// Return mock embeddings with dummy values
|
||||||
|
data := make([]EmbeddingData, len(req.Input))
|
||||||
|
for i := range req.Input {
|
||||||
|
data[i] = EmbeddingData{
|
||||||
|
Object: "embedding",
|
||||||
|
Embedding: make([]float64, 1024), // Mistral embedding models typically use 1024 dimensions
|
||||||
|
Index: i,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response := EmbeddingResponse{
|
||||||
|
ID: "mock-embedding-id",
|
||||||
|
Object: "list",
|
||||||
|
Data: data,
|
||||||
|
Model: req.Model,
|
||||||
|
Usage: EmbeddingUsageInfo{
|
||||||
|
PromptTokens: len(req.Input) * 10,
|
||||||
|
TotalTokens: len(req.Input) * 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &response, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user