feat: update openai/azure params (#1604)

* feat(openai): use max_completion_tokens & reasoning_effort params

* feat(openai): use developer prompt for reasoning models

* docs: update openai config in readme

* refactor: follow lua style quotes

* fix(azure): rename max_tokens to max_completion_tokens

* refactor(azure): remove duplicate field

* refactor: update types

* refactor(azure): update type
This commit is contained in:
kernitus
2025-03-18 11:40:20 +00:00
committed by GitHub
parent 62a8c07e91
commit 10ce065d9e
6 changed files with 33 additions and 33 deletions

View File

@@ -2,7 +2,8 @@
---@field deployment string
---@field api_version string
---@field temperature number
---@field max_tokens number
---@field max_completion_tokens number
---@field reasoning_effort? string
local Utils = require("avante.utils")
local P = require("avante.providers")
@@ -13,12 +14,8 @@ local M = {}
-- Name under which this provider's API key is looked up
-- (presumably an environment variable; confirm against the provider loader).
M.api_key_name = "AZURE_OPENAI_API_KEY"
-- Explicit aliases: reuse the handlers defined on O for this provider.
M.parse_messages = O.parse_messages
M.parse_response = O.parse_response
M.parse_response_without_stream = O.parse_response_without_stream
M.is_disable_stream = O.is_disable_stream
M.is_o_series_model = O.is_o_series_model
M.role_map = O.role_map
-- Inherit from OpenAI class
-- Any field that is not set directly on M is resolved on O through __index.
setmetatable(M, { __index = O })
function M:parse_curl_args(prompt_opts)
local provider_conf, request_body = P.parse_config(self)
@@ -35,11 +32,8 @@ function M:parse_curl_args(prompt_opts)
end
end
-- NOTE: When using "o" series set the supported parameters only
if O.is_o_series_model(provider_conf.model) then
request_body.max_tokens = nil
request_body.temperature = 1
end
-- NOTE: When using reasoning models set supported parameters
self.set_reasoning_params(provider_conf, request_body)
return {
url = Utils.url_join(

View File

@@ -67,16 +67,23 @@ function M.get_user_message(opts)
)
end
--- Tell whether a model name starts with "o" followed by at least one digit.
---@param model string|nil model name to test
---@return boolean|nil is_match the falsy input itself when `model` is nil/false, otherwise a boolean
function M.is_o_series_model(model)
  -- Falsy input is handed back unchanged instead of being coerced to false.
  if not model then return model end
  local matched = string.match(model, "^o%d+")
  return matched ~= nil
end
--- Tell whether a model name is treated as a reasoning model.
--- The check is purely name-based: the name must begin with "o" followed by one or more digits.
---@param model string|nil model name to test
---@return boolean|nil is_reasoning the falsy input itself when `model` is nil/false, otherwise a boolean
function M.is_reasoning_model(model)
  -- Falsy input is handed back unchanged instead of being coerced to false.
  if not model then return model end
  local prefix = string.match(model, "^o%d+")
  return prefix ~= nil
end
--- Adjust a request body in place according to whether the configured model
--- is a reasoning model (as decided by `M.is_reasoning_model`).
---
--- * Reasoning model: `temperature` is forced to 1 and any configured
---   `reasoning_effort` is left untouched.
--- * Any other model: `reasoning_effort` is removed from the request body.
---@param provider_conf table provider configuration; only `model` is read here
---@param request_body table request payload, mutated in place
function M.set_reasoning_params(provider_conf, request_body)
  if M.is_reasoning_model(provider_conf.model) then
    -- reasoning_effort needs no handling on this branch: whatever the
    -- configuration put into request_body is sent as-is.
    request_body.temperature = 1
  else
    -- Do not send reasoning_effort when the model is not a reasoning model.
    request_body.reasoning_effort = nil
  end
end
function M:parse_messages(opts)
local messages = {}
local provider_conf, _ = P.parse_config(self)
-- NOTE: Handle the case where the selected model is the `o1` model
-- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
if self.is_o_series_model(provider_conf.model) then
table.insert(messages, { role = "user", content = opts.system_prompt })
if self.is_reasoning_model(provider_conf.model) then
table.insert(messages, { role = "developer", content = opts.system_prompt })
else
table.insert(messages, { role = "system", content = opts.system_prompt })
end
@@ -298,12 +305,8 @@ function M:parse_curl_args(prompt_opts)
request_body.include_reasoning = true
end
-- NOTE: When using "o" series set the supported parameters only
if self.is_o_series_model(provider_conf.model) then
request_body.max_completion_tokens = request_body.max_tokens
request_body.max_tokens = nil
request_body.temperature = 1
end
-- NOTE: When using reasoning models set supported parameters
self.set_reasoning_params(provider_conf, request_body)
local tools = nil
if not disable_tools and prompt_opts.tools then