278 lines
8.9 KiB
Go
278 lines
8.9 KiB
Go
package llm
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"git.gorlug.de/code/ersteller"
)
|
|
|
|
// ChatCompletionRequestResponseFormatType selects the output format the
// chat completion endpoint should produce (plain text or a JSON object).
type ChatCompletionRequestResponseFormatType string
|
|
|
|
const (
	// ChatCompletionRequestResponseFormatText requests a plain-text response.
	ChatCompletionRequestResponseFormatText ChatCompletionRequestResponseFormatType = "text"
	// ChatCompletionRequestResponseFormatJSON requests a JSON-object response.
	ChatCompletionRequestResponseFormatJSON ChatCompletionRequestResponseFormatType = "json_object"
)
|
|
|
|
// ChatCompletionRequestResponseFormat wraps the desired response format for a
// chat completion request.
type ChatCompletionRequestResponseFormat struct {
	Type ChatCompletionRequestResponseFormatType `json:"type,omitempty"`
}
|
|
|
|
// ChatCompletionRequest represents the request body for chat completions.
// Pointer fields distinguish "unset" (omitted from JSON) from an explicit
// zero value.
type ChatCompletionRequest struct {
	Model       string    `json:"model"`                 // model identifier to query
	Messages    []Message `json:"messages"`              // conversation history, oldest first
	Temperature *float64  `json:"temperature,omitempty"` // sampling temperature; nil = server default
	TopP        *float64  `json:"top_p,omitempty"`       // nucleus-sampling probability mass; nil = server default
	MaxTokens   *int      `json:"max_tokens,omitempty"`  // completion token limit; nil = server default
	// Stream is always serialized (no omitempty), so false is sent explicitly.
	Stream bool `json:"stream"`
	// Stop accepts either a single string or a list of strings — TODO confirm
	// against the API spec; the code itself does not constrain it.
	Stop       interface{} `json:"stop,omitempty"`
	RandomSeed *int        `json:"random_seed,omitempty"` // seed for reproducible sampling
	Tools      []Tool      `json:"tools,omitempty"`       // function tools the model may call
	// ToolChoice is either a mode string (e.g. "auto") or a structured
	// selector — presumably; verify against the API documentation.
	ToolChoice       interface{} `json:"tool_choice,omitempty"`
	PresencePenalty  float64     `json:"presence_penalty,omitempty"`
	FrequencyPenalty float64     `json:"frequency_penalty,omitempty"`
	N                *int        `json:"n,omitempty"` // number of completions to generate
	SafePrompt       bool        `json:"safe_prompt"` // always serialized, like Stream
	// NOTE(review): `omitempty` has no effect on a non-pointer struct field,
	// so an unset ResponseFormat is still serialized as {"type":""}-less `{}`
	// — confirm the API tolerates an empty response_format object.
	ResponseFormat ChatCompletionRequestResponseFormat `json:"response_format,omitempty"`
}
|
|
|
|
// ContentItem represents a single content item (text or image) inside a
// multi-part message content list.
type ContentItem struct {
	Type     string `json:"type"`                // discriminator, e.g. "text" or "image_url" — TODO confirm exact values
	Text     string `json:"text,omitempty"`      // populated for text items
	ImageURL string `json:"image_url,omitempty"` // populated for image items
}
|
|
|
|
// Message represents a chat message exchanged with the model.
type Message struct {
	Role string `json:"role"` // e.g. "system", "user", "assistant" — TODO confirm accepted values
	// Content is either a plain string or a []ContentItem for multi-part
	// messages; the code places no constraint beyond interface{}.
	Content interface{} `json:"content"`
}
|
|
|
|
// Tool represents a function tool the model is allowed to invoke.
type Tool struct {
	Type     string   `json:"type"`     // tool kind; presumably always "function" — verify against callers
	Function Function `json:"function"` // the callable's definition
}
|
|
|
|
// Function represents a function definition exposed to the model as a tool.
type Function struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters"` // JSON-Schema-style parameter description
	Strict      bool                   `json:"strict,omitempty"`
}
|
|
|
|
// ChatCompletionResponse represents the API response to a chat completion
// request.
type ChatCompletionResponse struct {
	ID      string                 `json:"id"`
	Object  string                 `json:"object"`
	Created int64                  `json:"created"` // creation timestamp; presumably Unix seconds — confirm
	Model   string                 `json:"model"`   // model that produced the completion
	Choices []ChatCompletionChoice `json:"choices"` // one entry per requested completion
	Usage   UsageInfo              `json:"usage"`   // token accounting for the call
}
|
|
|
|
// ChatCompletionChoice represents one completion choice in a response.
type ChatCompletionChoice struct {
	Index        int     `json:"index"`         // position within Choices
	Message      Message `json:"message"`       // the generated assistant message
	FinishReason string  `json:"finish_reason"` // why generation stopped, e.g. "stop" — TODO confirm values
}
|
|
|
|
// UsageInfo represents token usage information for a chat completion call.
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`     // tokens consumed by the input
	CompletionTokens int `json:"completion_tokens"` // tokens generated in the output
	TotalTokens      int `json:"total_tokens"`      // prompt + completion
}
|
|
|
|
// EmbeddingRequest represents the request body for embeddings.
type EmbeddingRequest struct {
	Model          string   `json:"model"`                     // embedding model identifier
	Input          []string `json:"input"`                     // texts to embed; one result per entry
	EncodingFormat string   `json:"encoding_format,omitempty"` // optional output encoding; server default when empty
}
|
|
|
|
// EmbeddingResponse represents the API response for embeddings.
type EmbeddingResponse struct {
	ID     string             `json:"id"`
	Object string             `json:"object"`
	Data   []EmbeddingData    `json:"data"` // one embedding per input string
	Model  string             `json:"model"`
	Usage  EmbeddingUsageInfo `json:"usage"` // token accounting for the call
}
|
|
|
|
// EmbeddingData represents a single embedding result.
type EmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"` // the embedding vector
	Index     int       `json:"index"`     // position of the corresponding input string
}
|
|
|
|
// EmbeddingUsageInfo represents token usage information for embeddings.
type EmbeddingUsageInfo struct {
	PromptTokens int `json:"prompt_tokens"` // tokens consumed by the inputs
	TotalTokens  int `json:"total_tokens"`
}
|
|
|
|
// MistralClientImpl is an HTTP client for the Mistral API; it implements the
// MistralClient interface.
type MistralClientImpl struct {
	APIKey     string       // bearer token sent in the Authorization header
	BaseURL    string       // API root, e.g. "https://api.mistral.ai"; no trailing slash
	HTTPClient *http.Client // underlying transport; replaceable for testing
}
|
|
|
|
// MistralClientOption is a functional option that mutates a client during
// construction in NewMistralClient.
type MistralClientOption = func(*MistralClientImpl)
|
|
|
|
func BaseUrlOption(url string) MistralClientOption {
|
|
return func(c *MistralClientImpl) {
|
|
c.BaseURL = url
|
|
}
|
|
}
|
|
|
|
func NewMistralClient(apiKey string, options ...MistralClientOption) *MistralClientImpl {
|
|
client := &MistralClientImpl{
|
|
APIKey: apiKey,
|
|
BaseURL: "https://api.mistral.ai",
|
|
HTTPClient: &http.Client{},
|
|
}
|
|
for _, option := range options {
|
|
option(client)
|
|
}
|
|
return client
|
|
}
|
|
|
|
// CreateChatCompletion sends a chat completion request to the API
|
|
func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
|
jsonData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error marshaling request: %v", err)
|
|
}
|
|
|
|
request, err := http.NewRequest("POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating request: %v", err)
|
|
}
|
|
|
|
request.Header.Set("Content-Type", "application/json")
|
|
request.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
ersteller.Debug("authorization", "Bearer ", c.APIKey)
|
|
ersteller.Debug("base url", c.BaseURL)
|
|
|
|
response, err := c.HTTPClient.Do(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error making request: %v", err)
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
body, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading response body: %v", err)
|
|
}
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", response.StatusCode, string(body))
|
|
}
|
|
|
|
var result ChatCompletionResponse
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
return nil, fmt.Errorf("error unmarshaling response: %v", err)
|
|
}
|
|
|
|
return &result, nil
|
|
}
|
|
|
|
// CreateEmbedding sends an embedding request to the API
|
|
func (c *MistralClientImpl) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
|
jsonData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error marshaling request: %v", err)
|
|
}
|
|
|
|
request, err := http.NewRequest("POST", c.BaseURL+"/v1/embeddings", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating request: %v", err)
|
|
}
|
|
|
|
request.Header.Set("Content-Type", "application/json")
|
|
request.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
|
|
response, err := c.HTTPClient.Do(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error making request: %v", err)
|
|
}
|
|
defer response.Body.Close()
|
|
|
|
body, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading response body: %v", err)
|
|
}
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API request failed with status %d: %s", response.StatusCode, string(body))
|
|
}
|
|
|
|
var result EmbeddingResponse
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
return nil, fmt.Errorf("error unmarshaling response: %v", err)
|
|
}
|
|
|
|
return &result, nil
|
|
}
|
|
|
|
// MistralClient abstracts the Mistral API operations so that a mock
// (MistralClientMock) can stand in for the real client in tests.
type MistralClient interface {
	CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error)
	CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error)
}
|
|
|
|
// MistralClientMock is a test double for MistralClient that returns canned
// responses without any network access.
type MistralClientMock struct {
	MockedContent string // chat message content to return; a built-in default is used when empty
}
|
|
|
|
func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
|
ersteller.Debug("mocking the mistral client")
|
|
if m.MockedContent == "" {
|
|
m.MockedContent = `{"name":"Meetup app","description":"An app to make notes for meetups","entities":[{"name":"Meetup","fields":[{"name":"name","type":"string"},{"name":"description","type":"string","longText":true},{"name":"date","type":"time"}],"entities":[{"name":"Notes","fields":[{"name":"title","type":"string"},{"name":"description","type":"string","longText":true}]}]}]}`
|
|
}
|
|
response := ChatCompletionResponse{
|
|
Choices: []ChatCompletionChoice{
|
|
{
|
|
Message: Message{
|
|
Content: m.MockedContent,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
return &response, nil
|
|
//return nil, errors.New("something went wrong")
|
|
}
|
|
|
|
func (m MistralClientMock) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
|
ersteller.Debug("mocking the mistral client embedding")
|
|
// Return mock embeddings with dummy values
|
|
data := make([]EmbeddingData, len(req.Input))
|
|
for i := range req.Input {
|
|
data[i] = EmbeddingData{
|
|
Object: "embedding",
|
|
Embedding: make([]float64, 1024), // Mistral embedding models typically use 1024 dimensions
|
|
Index: i,
|
|
}
|
|
}
|
|
response := EmbeddingResponse{
|
|
ID: "mock-embedding-id",
|
|
Object: "list",
|
|
Data: data,
|
|
Model: req.Model,
|
|
Usage: EmbeddingUsageInfo{
|
|
PromptTokens: len(req.Input) * 10,
|
|
TotalTokens: len(req.Input) * 10,
|
|
},
|
|
}
|
|
return &response, nil
|
|
}
|