diff --git a/README.md b/README.md index 89c5024..ab0426b 100644 --- a/README.md +++ b/README.md @@ -65,10 +65,10 @@ For building binary if you wish to build from source, then `cargo` is required. openai = { endpoint = "https://api.openai.com/v1", model = "gpt-4o", -- your desired model (or use gpt-4o, etc.) - timeout = 30000, -- timeout in milliseconds - temperature = 0, -- adjust if needed - max_tokens = 4096, - -- reasoning_effort = "high" -- only supported for reasoning models (o1, etc.) + timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models + temperature = 0, + max_completion_tokens = 8192, -- Increase this to include reasoning tokens (for reasoning models) + --reasoning_effort = "medium", -- low|medium|high, only used for reasoning models }, }, -- if you want to build from source then do `make BUILD_FROM_SOURCE=true` diff --git a/cursor-planning-mode.md b/cursor-planning-mode.md index 4a7d99f..a78d3d9 100644 --- a/cursor-planning-mode.md +++ b/cursor-planning-mode.md @@ -35,7 +35,7 @@ Then enable it in avante.nvim: api_key_name = 'GROQ_API_KEY', endpoint = 'https://api.groq.com/openai/v1/', model = 'llama-3.3-70b-versatile', - max_tokens = 32768, -- remember to increase this value, otherwise it will stop generating halfway + max_completion_tokens = 32768, -- remember to increase this value, otherwise it will stop generating halfway }, }, --- ... existing configurations diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 359dc58..aa43db4 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -194,9 +194,10 @@ M._defaults = { openai = { endpoint = "https://api.openai.com/v1", model = "gpt-4o", - timeout = 30000, -- Timeout in milliseconds + timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models temperature = 0, - max_tokens = 16384, + max_completion_tokens = 16384, -- Increase this to include reasoning tokens (for reasoning models) + reasoning_effort = "medium", -- low|medium|high, only used for reasoning models }, ---@type AvanteSupportedProvider copilot = { @@ -212,10 +213,11 @@ M._defaults = { azure = { endpoint = "", -- example: "https://.openai.azure.com" deployment = "", -- Azure deployment name (e.g., "gpt-4o", "my-gpt-4o-deployment") - api_version = "2024-06-01", - timeout = 30000, -- Timeout in milliseconds + api_version = "2024-12-01-preview", + timeout = 30000, -- Timeout in milliseconds, increase this for reasoning models temperature = 0, - max_tokens = 20480, + max_completion_tokens = 20480, -- Increase this to include reasoning tokens (for reasoning models) + reasoning_effort = "medium", -- low|medium|high, only used for reasoning models }, ---@type AvanteSupportedProvider claude = { diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 477241b..624eddb 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -2,7 +2,8 @@ ---@field deployment string ---@field api_version string ---@field temperature number ----@field max_tokens number +---@field max_completion_tokens number +---@field reasoning_effort? string local Utils = require("avante.utils") local P = require("avante.providers") @@ -13,12 +14,8 @@ local M = {} M.api_key_name = "AZURE_OPENAI_API_KEY" -M.parse_messages = O.parse_messages -M.parse_response = O.parse_response -M.parse_response_without_stream = O.parse_response_without_stream -M.is_disable_stream = O.is_disable_stream -M.is_o_series_model = O.is_o_series_model -M.role_map = O.role_map +-- Inherit from OpenAI class +setmetatable(M, { __index = O }) function M:parse_curl_args(prompt_opts) local provider_conf, request_body = P.parse_config(self) @@ -35,11 +32,8 @@ function M:parse_curl_args(prompt_opts) end end - -- NOTE: When using "o" series set the supported parameters only - if O.is_o_series_model(provider_conf.model) then - request_body.max_tokens = nil - request_body.temperature = 1 - end + -- NOTE: When using reasoning models set supported parameters + self.set_reasoning_params(provider_conf, request_body) return { url = Utils.url_join( diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index acd9941..4ce11ba 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -67,16 +67,23 @@ function M.get_user_message(opts) ) end -function M.is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end +function M.is_reasoning_model(model) return model and string.match(model, "^o%d+") ~= nil end + +function M.set_reasoning_params(provider_conf, request_body) + if M.is_reasoning_model(provider_conf.model) then + request_body.temperature = 1 + request_body.reasoning_effort = request_body.reasoning_effort + else + request_body.reasoning_effort = nil + end +end function M:parse_messages(opts) local messages = {} local provider_conf, _ = P.parse_config(self) - -- NOTE: Handle the case where the selected model is the `o1` model - -- "o1" models are "smart" enough to understand user prompt as a system prompt in this context - if self.is_o_series_model(provider_conf.model) then - table.insert(messages, { role = "user", content = opts.system_prompt }) + if self.is_reasoning_model(provider_conf.model) then + table.insert(messages, { role = "developer", content = opts.system_prompt }) else table.insert(messages, { role = "system", content = opts.system_prompt }) end @@ -298,12 +305,8 @@ function M:parse_curl_args(prompt_opts) request_body.include_reasoning = true end - -- NOTE: When using "o" series set the supported parameters only - if self.is_o_series_model(provider_conf.model) then - request_body.max_completion_tokens = request_body.max_tokens - request_body.max_tokens = nil - request_body.temperature = 1 - end + -- NOTE: When using reasoning models set supported parameters + self.set_reasoning_params(provider_conf, request_body) local tools = nil if not disable_tools and prompt_opts.tools then diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 036e845..6d2ddd9 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -216,6 +216,7 @@ vim.g.avante_login = vim.g.avante_login ---@field __inherited_from? string ---@field temperature? number ---@field max_tokens? number +---@field max_completion_tokens? number ---@field reasoning_effort? string ---@field display_name? string ---