Files
ersteller/llm/mistral.go
T

167 lines
5.7 KiB
Go

package llm
import (
"bytes"
"encoding/json"
"fmt"
"git.gorlug.de/code/golang/ersteller-lib/ersteller"
"io"
"net/http"
)
// ChatCompletionRequestResponseFormatType is the string placed in the "type"
// key of a request's response format.
type ChatCompletionRequestResponseFormatType string

// Response format types defined by this package.
const (
	// ChatCompletionRequestResponseFormatText is serialized as "text".
	ChatCompletionRequestResponseFormatText = ChatCompletionRequestResponseFormatType("text")
	// ChatCompletionRequestResponseFormatJSON is serialized as "json_object".
	ChatCompletionRequestResponseFormatJSON = ChatCompletionRequestResponseFormatType("json_object")
)
// ChatCompletionRequestResponseFormat is the object serialized under the
// "response_format" key of a ChatCompletionRequest.
type ChatCompletionRequestResponseFormat struct {
	// Type is written as "type" and is left out of the JSON when it is empty.
	Type ChatCompletionRequestResponseFormatType `json:"type,omitempty"`
}
// ChatCompletionRequest is the JSON body that CreateChatCompletion posts to
// the chat completions endpoint. The field order below is the order in which
// encoding/json writes the keys.
type ChatCompletionRequest struct {
	// Always serialized; a nil Messages slice is written as null.
	Model    string    `json:"model"`
	Messages []Message `json:"messages"`

	// Optional values: a nil pointer drops the key from the JSON.
	Temperature *float64 `json:"temperature,omitempty"`
	TopP        *float64 `json:"top_p,omitempty"`
	MaxTokens   *int     `json:"max_tokens,omitempty"`

	// Always serialized, including when false.
	Stream bool `json:"stream"`

	// Dropped from the JSON when nil, empty or zero. Stop and ToolChoice are
	// untyped, so the caller decides which JSON shape is sent for them.
	Stop             interface{} `json:"stop,omitempty"`
	RandomSeed       *int        `json:"random_seed,omitempty"`
	Tools            []Tool      `json:"tools,omitempty"`
	ToolChoice       interface{} `json:"tool_choice,omitempty"`
	PresencePenalty  float64     `json:"presence_penalty,omitempty"`
	FrequencyPenalty float64     `json:"frequency_penalty,omitempty"`
	N                *int        `json:"n,omitempty"`

	// Always serialized, including when false.
	SafePrompt bool `json:"safe_prompt"`

	// NOTE(review): encoding/json never treats a struct value as empty, so
	// omitempty has no effect here and "response_format" is always written
	// (as {} when Type is unset).
	ResponseFormat ChatCompletionRequestResponseFormat `json:"response_format,omitempty"`
}
// ContentItem is one part of a message's content, either text or an image.
// Type is always serialized; Text and ImageURL are left out of the JSON when
// they are empty.
type ContentItem struct {
	Type     string `json:"type"`
	Text     string `json:"text,omitempty"`
	ImageURL string `json:"image_url,omitempty"`
}
// Message is a single chat message. Both keys are always serialized.
//
// Content is untyped: on the way out it is encoded as whatever value the
// caller stored, and on the way in it holds what encoding/json produces for
// the received JSON value (a string for a JSON string).
type Message struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}
// Tool is one entry of a request's "tools" list. It pairs a type string with
// the Function definition it wraps; both keys are always serialized.
type Tool struct {
	Type     string   `json:"type"`
	Function Function `json:"function"`
}
// Function is the definition of a callable function carried by a Tool.
//
// Name and Parameters are always serialized (a nil Parameters map is written
// as null); Description and Strict are left out of the JSON when empty or
// false.
type Function struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters"`
	Strict      bool                   `json:"strict,omitempty"`
}
// ChatCompletionResponse is the structure a successful chat completions reply
// is decoded into by CreateChatCompletion.
type ChatCompletionResponse struct {
	ID      string                 `json:"id"`
	Object  string                 `json:"object"`
	Created int64                  `json:"created"`
	Model   string                 `json:"model"`
	Choices []ChatCompletionChoice `json:"choices"`
	Usage   UsageInfo              `json:"usage"`
}
// ChatCompletionChoice is one element of ChatCompletionResponse.Choices: the
// choice's index, its message, and the finish reason string.
type ChatCompletionChoice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}
// UsageInfo holds the token counts found under a response's "usage" key:
// prompt tokens, completion tokens and their reported total.
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// MistralClientImpl is the HTTP-backed Mistral API client.
//
// CreateChatCompletion sends APIKey as the bearer token, prefixes the
// endpoint path with BaseURL, and performs the call through HTTPClient.
type MistralClientImpl struct {
	APIKey     string
	BaseURL    string
	HTTPClient *http.Client
}
// NewMistralClient returns a client that authenticates with apiKey and talks
// to https://api.mistral.ai.
//
// The client gets its own zero-value http.Client, which means no client-level
// timeout is configured; callers that need one can replace HTTPClient.
func NewMistralClient(apiKey string) *MistralClientImpl {
	client := new(MistralClientImpl)
	client.APIKey = apiKey
	client.BaseURL = "https://api.mistral.ai"
	client.HTTPClient = new(http.Client)
	return client
}
// CreateChatCompletion sends req to the chat completions endpoint and returns
// the decoded reply.
//
// req is marshaled to JSON and POSTed to c.BaseURL + "/v1/chat/completions"
// with a JSON Content-Type header and c.APIKey as the bearer token. When
// c.HTTPClient is nil, http.DefaultClient is used, so a client built as a
// plain struct literal works instead of panicking.
//
// An error is returned when req is nil, when marshaling, building, sending,
// reading or decoding fails, or when the server answers with any status other
// than 200 OK; in that last case the message carries the status code and the
// raw response body. Underlying errors are wrapped with %w, so callers can
// inspect them with errors.Is and errors.As. The returned response is nil
// whenever the error is non-nil.
func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	// A nil request would otherwise be marshaled and sent as the body "null".
	if req == nil {
		return nil, fmt.Errorf("error marshaling request: request is nil")
	}

	payload, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("error marshaling request: %w", err)
	}

	httpReq, err := http.NewRequest(http.MethodPost, c.BaseURL+"/v1/chat/completions", bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)

	// Calling Do on a nil *http.Client dereferences nil, so fall back to the
	// shared default client when none was configured.
	httpClient := c.HTTPClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}

	resp, err := httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so that a failure reply can
	// be included in the error message.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result ChatCompletionResponse
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("error unmarshaling response: %w", err)
	}
	return &result, nil
}
// MistralClient is the chat-completion operation shared by the real client
// (*MistralClientImpl) and the test double (MistralClientMock), so callers can
// depend on the behaviour rather than on a concrete type.
type MistralClient interface {
	// CreateChatCompletion performs one chat completion for req and returns
	// the response or an error.
	CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error)
}
// MistralClientMock is a MistralClient stand-in that performs no HTTP call
// and always answers with the same canned response.
type MistralClientMock struct{}

// CreateChatCompletion ignores req, logs a debug line, and returns a response
// with exactly one choice whose message content is a fixed JSON document
// describing a "Meetup app". The error is always nil; to exercise a caller's
// failure path, return a non-nil error from here instead.
func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	const cannedContent = `{"name":"Meetup app","description":"An app to make notes for meetups","entities":[{"name":"Meetup","fields":[{"name":"name","type":"string"},{"name":"description","type":"string","longText":true},{"name":"date","type":"time"}],"entities":[{"name":"Notes","fields":[{"name":"title","type":"string"},{"name":"description","type":"string","longText":true}]}]}]}`

	ersteller.Debug("mocking the mistral client")

	reply := new(ChatCompletionResponse)
	reply.Choices = []ChatCompletionChoice{
		{Message: Message{Content: cannedContent}},
	}
	return reply, nil
}