// Package llm provides a small client for the Mistral chat-completion and
// embedding HTTP endpoints, plus an in-memory mock implementing the same
// MistralClient interface.
package llm

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"git.gorlug.de/code/ersteller"
)

// ChatCompletionRequestResponseFormatType is the value of the "type" key in a
// request's "response_format" object.
type ChatCompletionRequestResponseFormatType string

const (
	// ChatCompletionRequestResponseFormatText selects the "text" format.
	ChatCompletionRequestResponseFormatText ChatCompletionRequestResponseFormatType = "text"
	// ChatCompletionRequestResponseFormatJSON selects the "json_object" format.
	ChatCompletionRequestResponseFormatJSON ChatCompletionRequestResponseFormatType = "json_object"
)

// ChatCompletionRequestResponseFormat is the "response_format" object of a
// chat completion request.
type ChatCompletionRequestResponseFormat struct {
	Type ChatCompletionRequestResponseFormatType `json:"type,omitempty"`
}

// ChatCompletionRequest represents the request body for chat completions.
//
// Nil pointer fields, nil Stop/ToolChoice, empty Tools and zero penalties are
// left out of the encoded body. ResponseFormat is left out when its Type is
// empty (see MarshalJSON).
type ChatCompletionRequest struct {
	Model            string                              `json:"model"`
	Messages         []Message                           `json:"messages"`
	Temperature      *float64                            `json:"temperature,omitempty"`
	TopP             *float64                            `json:"top_p,omitempty"`
	MaxTokens        *int                                `json:"max_tokens,omitempty"`
	Stream           bool                                `json:"stream"`
	Stop             interface{}                         `json:"stop,omitempty"`
	RandomSeed       *int                                `json:"random_seed,omitempty"`
	Tools            []Tool                              `json:"tools,omitempty"`
	ToolChoice       interface{}                         `json:"tool_choice,omitempty"`
	PresencePenalty  float64                             `json:"presence_penalty,omitempty"`
	FrequencyPenalty float64                             `json:"frequency_penalty,omitempty"`
	N                *int                                `json:"n,omitempty"`
	SafePrompt       bool                                `json:"safe_prompt"`
	ResponseFormat   ChatCompletionRequestResponseFormat `json:"response_format,omitempty"`
}

// MarshalJSON encodes the request, dropping "response_format" when
// ResponseFormat.Type is empty. encoding/json ignores `omitempty` on
// struct-typed fields, so without this an unset ResponseFormat would be sent
// as "response_format":{}.
func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) {
	// wire has the same fields but no methods, so marshaling it cannot
	// recurse back into this method.
	type wire ChatCompletionRequest
	out := struct {
		wire
		// Shadows wire.ResponseFormat: encoding/json prefers the shallower
		// field of the same name. As a pointer, omitempty takes effect.
		ResponseFormat *ChatCompletionRequestResponseFormat `json:"response_format,omitempty"`
	}{wire: wire(r)}
	if r.ResponseFormat.Type != "" {
		rf := r.ResponseFormat
		out.ResponseFormat = &rf
	}
	return json.Marshal(out)
}

// ContentItem represents a single content item (text or image).
type ContentItem struct {
	Type     string `json:"type"`
	Text     string `json:"text,omitempty"`
	ImageURL string `json:"image_url,omitempty"`
}

// Message represents a chat message.
//
// Content is untyped: the mock stores a plain string, and when decoding a
// response it holds whatever encoding/json produced for the JSON value.
// NOTE(review): presumably callers may also pass []ContentItem — confirm.
type Message struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}

// Tool represents a function tool.
type Tool struct {
	Type     string   `json:"type"`
	Function Function `json:"function"`
}

// Function represents a function definition offered to the model as a tool.
type Function struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description,omitempty"`
	Parameters  map[string]interface{} `json:"parameters"`
	Strict      bool                   `json:"strict,omitempty"`
}

// ChatCompletionResponse represents the chat completion API response.
type ChatCompletionResponse struct {
	ID      string                 `json:"id"`
	Object  string                 `json:"object"`
	Created int64                  `json:"created"`
	Model   string                 `json:"model"`
	Choices []ChatCompletionChoice `json:"choices"`
	Usage   UsageInfo              `json:"usage"`
}

// ChatCompletionChoice represents a single completion choice.
type ChatCompletionChoice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}

// UsageInfo represents token usage information for a chat completion.
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// EmbeddingRequest represents the request body for embeddings.
type EmbeddingRequest struct {
	Model          string   `json:"model"`
	Input          []string `json:"input"`
	EncodingFormat string   `json:"encoding_format,omitempty"`
}

// EmbeddingResponse represents the API response for embeddings.
type EmbeddingResponse struct {
	ID     string             `json:"id"`
	Object string             `json:"object"`
	Data   []EmbeddingData    `json:"data"`
	Model  string             `json:"model"`
	Usage  EmbeddingUsageInfo `json:"usage"`
}

// EmbeddingData represents a single embedding result.
type EmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"`
}

// EmbeddingUsageInfo represents token usage information for embeddings.
type EmbeddingUsageInfo struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// APIError is returned when the API answers with a status other than 200 OK.
// Its Error text is "API request failed with status <code>: <body>"; use
// errors.As to inspect StatusCode and the raw response Body.
type APIError struct {
	StatusCode int
	Body       string
}

// Error implements the error interface.
func (e *APIError) Error() string {
	return fmt.Sprintf("API request failed with status %d: %s", e.StatusCode, e.Body)
}

// MistralClientImpl is an HTTP client for the Mistral API. Create one with
// NewMistralClient.
type MistralClientImpl struct {
	APIKey     string
	BaseURL    string
	HTTPClient *http.Client
}

// Compile-time checks that both implementations satisfy MistralClient.
var (
	_ MistralClient = (*MistralClientImpl)(nil)
	_ MistralClient = MistralClientMock{}
)

// MistralClientOption customizes a MistralClientImpl inside NewMistralClient.
type MistralClientOption = func(*MistralClientImpl)

// BaseUrlOption overrides the API base URL (default "https://api.mistral.ai").
// Endpoint paths such as "/v1/embeddings" are appended verbatim, so the URL
// should not end with a slash.
func BaseUrlOption(url string) MistralClientOption {
	return func(c *MistralClientImpl) {
		c.BaseURL = url
	}
}

// HTTPClientOption replaces the *http.Client used for requests, e.g. to set a
// Timeout or a custom Transport. A nil client is ignored.
func HTTPClientOption(httpClient *http.Client) MistralClientOption {
	return func(c *MistralClientImpl) {
		if httpClient != nil {
			c.HTTPClient = httpClient
		}
	}
}

// NewMistralClient returns a client that authenticates with apiKey against
// https://api.mistral.ai, after applying options in order (nil options are
// skipped). The default *http.Client has no timeout; supply one through
// HTTPClientOption to bound how long a request may take.
func NewMistralClient(apiKey string, options ...MistralClientOption) *MistralClientImpl {
	client := &MistralClientImpl{
		APIKey:     apiKey,
		BaseURL:    "https://api.mistral.ai",
		HTTPClient: &http.Client{},
	}
	for _, option := range options {
		if option != nil {
			option(client)
		}
	}
	return client
}

// postJSON marshals payload, POSTs it to c.BaseURL+path with JSON and bearer
// auth headers, and decodes a 200 OK body into out. Any other status yields an
// *APIError carrying the response body; other failures wrap the cause with %w.
func (c *MistralClientImpl) postJSON(path string, payload, out interface{}) error {
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("error marshaling request: %w", err)
	}

	request, err := http.NewRequest(http.MethodPost, c.BaseURL+path, bytes.NewReader(jsonData))
	if err != nil {
		return fmt.Errorf("error creating request: %w", err)
	}
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Authorization", "Bearer "+c.APIKey)

	response, err := c.HTTPClient.Do(request)
	if err != nil {
		return fmt.Errorf("error making request: %w", err)
	}
	defer response.Body.Close()

	// Read the body before checking the status so it can be reported in the
	// error for non-200 responses as well as decoded on success.
	body, err := io.ReadAll(response.Body)
	if err != nil {
		return fmt.Errorf("error reading response body: %w", err)
	}
	if response.StatusCode != http.StatusOK {
		return &APIError{StatusCode: response.StatusCode, Body: string(body)}
	}

	if err := json.Unmarshal(body, out); err != nil {
		return fmt.Errorf("error unmarshaling response: %w", err)
	}
	return nil
}

// CreateChatCompletion sends a chat completion request to /v1/chat/completions.
func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	var result ChatCompletionResponse
	if err := c.postJSON("/v1/chat/completions", req, &result); err != nil {
		return nil, err
	}
	return &result, nil
}

// CreateEmbedding sends an embedding request to /v1/embeddings.
func (c *MistralClientImpl) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
	var result EmbeddingResponse
	if err := c.postJSON("/v1/embeddings", req, &result); err != nil {
		return nil, err
	}
	return &result, nil
}

// MistralClient is the set of API operations used by callers; it is
// implemented by MistralClientImpl and MistralClientMock.
type MistralClient interface {
	CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error)
	CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error)
}

// defaultMockedContent is the chat completion content returned by
// MistralClientMock when MockedContent is empty.
const defaultMockedContent = `{"name":"Meetup app","description":"An app to make notes for meetups","entities":[{"name":"Meetup","fields":[{"name":"name","type":"string"},{"name":"description","type":"string","longText":true},{"name":"date","type":"time"}],"entities":[{"name":"Notes","fields":[{"name":"title","type":"string"},{"name":"description","type":"string","longText":true}]}]}]}`

// mockEmbeddingDimensions is the length of each zero vector returned by
// MistralClientMock.CreateEmbedding. Mistral embedding models typically use
// 1024 dimensions.
const mockEmbeddingDimensions = 1024

// MistralClientMock is a MistralClient that performs no network requests.
type MistralClientMock struct {
	// MockedContent is returned as the single choice's message content; when
	// empty, defaultMockedContent is used instead.
	MockedContent string
}

// CreateChatCompletion returns a response with one choice whose message
// content is m.MockedContent (or the built-in default). It never fails.
func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	ersteller.Debug("mocking the mistral client")

	content := m.MockedContent
	if content == "" {
		content = defaultMockedContent
	}

	return &ChatCompletionResponse{
		Choices: []ChatCompletionChoice{
			{Message: Message{Content: content}},
		},
	}, nil
}

// CreateEmbedding returns one zero-valued embedding per input string, echoes
// req.Model, and reports 10 tokens per input as usage. It never fails.
func (m MistralClientMock) CreateEmbedding(req *EmbeddingRequest) (*EmbeddingResponse, error) {
	ersteller.Debug("mocking the mistral client embedding")

	data := make([]EmbeddingData, len(req.Input))
	for i := range req.Input {
		data[i] = EmbeddingData{
			Object:    "embedding",
			Embedding: make([]float64, mockEmbeddingDimensions),
			Index:     i,
		}
	}

	tokens := len(req.Input) * 10
	return &EmbeddingResponse{
		ID:     "mock-embedding-id",
		Object: "list",
		Data:   data,
		Model:  req.Model,
		Usage: EmbeddingUsageInfo{
			PromptTokens: tokens,
			TotalTokens:  tokens,
		},
	}, nil
}