diff --git a/llm/mistral.go b/llm/mistral.go new file mode 100644 index 0000000..24dcde1 --- /dev/null +++ b/llm/mistral.go @@ -0,0 +1,159 @@ +package llm + +import ( + "bytes" + "encoding/json" + "ersteller-lib" + "fmt" + "io" + "net/http" +) + +type ChatCompletionRequestResponseFormatType string + +const ( + ChatCompletionRequestResponseFormatText ChatCompletionRequestResponseFormatType = "text" + ChatCompletionRequestResponseFormatJSON ChatCompletionRequestResponseFormatType = "json_object" +) + +type ChatCompletionRequestResponseFormat struct { + Type ChatCompletionRequestResponseFormatType `json:"type,omitempty"` +} + +// ChatCompletionRequest represents the request body for chat completions +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Stream bool `json:"stream"` + Stop interface{} `json:"stop,omitempty"` + RandomSeed *int `json:"random_seed,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + N *int `json:"n,omitempty"` + SafePrompt bool `json:"safe_prompt"` + ResponseFormat ChatCompletionRequestResponseFormat `json:"response_format,omitempty"` +} + +// Message represents a chat message +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// Tool represents a function tool +type Tool struct { + Type string `json:"type"` + Function Function `json:"function"` +} + +// Function represents a function definition +type Function struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]interface{} `json:"parameters"` + Strict bool `json:"strict,omitempty"` +} + +// ChatCompletionResponse represents the API response +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage UsageInfo `json:"usage"` +} + +// ChatCompletionChoice represents a completion choice +type ChatCompletionChoice struct { + Index int `json:"index"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// UsageInfo represents token usage information +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Client represents a Mistral API client +type MistralClientImpl struct { + APIKey string + BaseURL string + HTTPClient *http.Client +} + +// NewClient creates a new Mistral API client +func NewMistralClient(apiKey string) *MistralClientImpl { + return &MistralClientImpl{ + APIKey: apiKey, + BaseURL: "https://api.mistral.ai", + HTTPClient: &http.Client{}, + } +} + +// CreateChatCompletion sends a chat completion request to the API +func (c *MistralClientImpl) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) { + jsonData, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + request, err := http.NewRequest("POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Authorization", "Bearer "+c.APIKey) + + response, err := c.HTTPClient.Do(request) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", response.StatusCode, string(body)) + } + + var result ChatCompletionResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return &result, nil +} + +type MistralClient interface { + CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) +} + +type MistralClientMock struct{} + +func (m MistralClientMock) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) { + ersteller_lib.Debug("mocking the mistral client") + response := ChatCompletionResponse{ + Choices: []ChatCompletionChoice{ + { + Message: Message{ + Content: `{"name":"Meetup app","description":"An app to make notes for meetups","entities":[{"name":"Meetup","fields":[{"name":"name","type":"string"},{"name":"description","type":"string","longText":true},{"name":"date","type":"time"}],"entities":[{"name":"Notes","fields":[{"name":"title","type":"string"},{"name":"description","type":"string","longText":true}]}]}]}`, + }, + }, + }, + } + return &response, nil + //return nil, errors.New("something went wrong") +}