diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index c657c74..064cb10 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -25,8 +25,8 @@ M.generate_prompts = function(opts) local Provider = opts.provider or P[Config.provider] local mode = opts.mode or "planning" ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor - local _, body_opts = P.parse_config(Provider) - local max_tokens = body_opts.max_tokens or 4096 + local _, request_body = P.parse_config(Provider) + local max_tokens = request_body.max_tokens or 4096 -- Check if the instructions contains an image path local image_paths = {} diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index 83ca617..270e1a2 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -18,31 +18,34 @@ M.parse_response = O.parse_response M.parse_response_without_stream = O.parse_response_without_stream M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local headers = { ["Content-Type"] = "application/json", } - if P.env.require_api_key(base) then headers["api-key"] = provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["api-key"] = provider.parse_api_key() end -- NOTE: When using "o" series set the supported parameters only - if O.is_o_series_model(base.model) then - body_opts.max_tokens = nil - body_opts.temperature = 1 + if O.is_o_series_model(provider_conf.model) then + request_body.max_tokens = nil + request_body.temperature = 1 end return { url = Utils.url_join( - base.endpoint, - "/openai/deployments/" .. base.deployment .. "/chat/completions?api-version=" .. base.api_version + provider_conf.endpoint, + "/openai/deployments/" + .. provider_conf.deployment + .. "/chat/completions?api-version=" + .. provider_conf.api_version ), - proxy = base.proxy, - insecure = base.allow_insecure, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { messages = M.parse_messages(prompt_opts), stream = true, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index 49720f3..07ef4f0 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -1,5 +1,4 @@ local Utils = require("avante.utils") -local Clipboard = require("avante.clipboard") local P = require("avante.providers") ---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table): table @@ -17,17 +16,14 @@ M.api_key_name = "BEDROCK_KEYS" M.use_xml_format = true M.load_model_handler = function() - local base, _ = P.parse_config(P["bedrock"]) - local bedrock_model = base.model - if base.model:match("anthropic") then bedrock_model = "claude" end + local provider_conf, _ = P.parse_config(P["bedrock"]) + local bedrock_model = provider_conf.model + if provider_conf.model:match("anthropic") then bedrock_model = "claude" end local ok, model_module = pcall(require, "avante.providers.bedrock." .. bedrock_model) - if ok then - return model_module - else - local error_msg = "Bedrock model handler not found: " .. bedrock_model - Utils.error(error_msg, { once = true, title = "Avante" }) - end + if ok then return model_module end + local error_msg = "Bedrock model handler not found: " .. bedrock_model + error(error_msg) end M.parse_response = function(ctx, data_stream, event_state, opts) @@ -46,8 +42,8 @@ M.parse_stream_data = function(data, opts) -- The `type` field in the decoded JSON determines how the response is handled. local bedrock_match = data:gmatch("event(%b{})") for bedrock_data_match in bedrock_match do - local data = vim.json.decode(bedrock_data_match) - local data_stream = vim.base64.decode(data.bytes) + local jsn = vim.json.decode(bedrock_data_match) + local data_stream = vim.base64.decode(jsn.bytes) local json = vim.json.decode(data_stream) M.parse_response({}, data_stream, json.type, opts) end @@ -60,6 +56,7 @@ M.parse_curl_args = function(provider, prompt_opts) local base, body_opts = P.parse_config(provider) local api_key = provider.parse_api_key() + if api_key == nil then error("Cannot get the bedrock api key!") end local parts = vim.split(api_key, ",") local aws_access_key_id = parts[1] local aws_secret_access_key = parts[2] @@ -108,7 +105,6 @@ M.on_error = function(result) end local error_msg = body.error.message - local error_type = body.error.type Utils.error(error_msg, { once = true, title = "Avante" }) end diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index ecc97dd..bf64fa3 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -226,7 +226,7 @@ end ---@param prompt_opts AvantePromptOptions ---@return table M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local headers = { ["Content-Type"] = "application/json", @@ -234,7 +234,7 @@ M.parse_curl_args = function(provider, prompt_opts) ["anthropic-beta"] = "prompt-caching-2024-07-31", } - if P.env.require_api_key(base) then headers["x-api-key"] = provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end local messages = M.parse_messages(prompt_opts) @@ -246,12 +246,12 @@ M.parse_curl_args = function(provider, prompt_opts) end return { - url = Utils.url_join(base.endpoint, "/v1/messages"), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join(provider_conf.endpoint, "/v1/messages"), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, system = { { type = "text", @@ -262,7 +262,7 @@ M.parse_curl_args = function(provider, prompt_opts) messages = messages, tools = tools, stream = true, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index 059e23e..f6c7953 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -70,7 +70,7 @@ M.parse_stream_data = function(data, opts) end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local headers = { ["Accept"] = "application/json", @@ -82,17 +82,17 @@ M.parse_curl_args = function(provider, prompt_opts) .. "." .. vim.version().patch, } - if P.env.require_api_key(base) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end + if P.env.require_api_key(provider_conf) then headers["Authorization"] = "Bearer " .. provider.parse_api_key() end return { - url = Utils.url_join(base.endpoint, "/chat"), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join(provider_conf.endpoint, "/chat"), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, stream = true, - }, M.parse_messages(prompt_opts), body_opts), + }, M.parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index 47ec4b4..050b934 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -249,7 +249,7 @@ M.parse_curl_args = function(provider, prompt_opts) -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local tools = {} if prompt_opts.tools then @@ -259,10 +259,10 @@ M.parse_curl_args = function(provider, prompt_opts) end return { - url = H.chat_completion_url(base.endpoint), - timeout = base.timeout, - proxy = base.proxy, - insecure = base.allow_insecure, + url = H.chat_completion_url(provider_conf.endpoint), + timeout = provider_conf.timeout, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = { ["Content-Type"] = "application/json", ["Authorization"] = "Bearer " .. M.state.github_token.token, @@ -270,11 +270,11 @@ M.parse_curl_args = function(provider, prompt_opts) ["Editor-Version"] = ("Neovim/%s.%s.%s"):format(vim.version().major, vim.version().minor, vim.version().patch), }, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, messages = M.parse_messages(prompt_opts), stream = true, tools = tools, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 1528b87..9a13346 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -82,26 +82,29 @@ M.parse_response = function(ctx, data_stream, _, opts) end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) - body_opts = vim.tbl_deep_extend("force", body_opts, { + request_body = vim.tbl_deep_extend("force", request_body, { generationConfig = { - temperature = body_opts.temperature, - maxOutputTokens = body_opts.max_tokens, + temperature = request_body.temperature, + maxOutputTokens = request_body.max_tokens, }, }) - body_opts.temperature = nil - body_opts.max_tokens = nil + request_body.temperature = nil + request_body.max_tokens = nil local api_key = provider.parse_api_key() if api_key == nil then error("Cannot get the gemini api key!") end return { - url = Utils.url_join(base.endpoint, base.model .. ":streamGenerateContent?alt=sse&key=" .. api_key), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join( + provider_conf.endpoint, + provider_conf.model .. ":streamGenerateContent?alt=sse&key=" .. api_key + ), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = { ["Content-Type"] = "application/json" }, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), + body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index fdb1cd3..3fa7915 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -346,9 +346,9 @@ M = setmetatable(M, { if t[k].has == nil then t[k].has = function() return E.parse_envvar(t[k]) ~= nil end end if t[k].setup == nil then - local base = M.parse_config(t[k]) + local provider_conf = M.parse_config(t[k]) t[k].setup = function() - if E.require_api_key(base) then t[k].parse_api_key() end + if E.require_api_key(provider_conf) then t[k].parse_api_key() end require("avante.tokenizers").setup(t[k].tokenizer_id) end end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 0543046..d03541b 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -275,14 +275,14 @@ M.parse_response_without_stream = function(data, _, opts) end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) - local disable_tools = base.disable_tools or false + local provider_conf, request_body = P.parse_config(provider) + local disable_tools = provider_conf.disable_tools or false local headers = { ["Content-Type"] = "application/json", } - if P.env.require_api_key(base) then + if P.env.require_api_key(provider_conf) then local api_key = provider.parse_api_key() if api_key == nil then error(Config.provider .. " API key is not set, please set it in your environment variable or config file") @@ -290,18 +290,18 @@ M.parse_curl_args = function(provider, prompt_opts) headers["Authorization"] = "Bearer " .. api_key end - if M.is_openrouter(base.endpoint) then + if M.is_openrouter(provider_conf.endpoint) then headers["HTTP-Referer"] = "https://github.com/yetone/avante.nvim" headers["X-Title"] = "Avante.nvim" - body_opts.include_reasoning = true + request_body.include_reasoning = true end -- NOTE: When using "o" series set the supported parameters only local stream = true - if M.is_o_series_model(base.model) then - body_opts.max_completion_tokens = body_opts.max_tokens - body_opts.max_tokens = nil - body_opts.temperature = 1 + if M.is_o_series_model(provider_conf.model) then + request_body.max_completion_tokens = request_body.max_tokens + request_body.max_tokens = nil + request_body.temperature = 1 end local tools = nil @@ -312,20 +312,20 @@ M.parse_curl_args = function(provider, prompt_opts) end end - Utils.debug("endpoint", base.endpoint) - Utils.debug("model", base.model) + Utils.debug("endpoint", provider_conf.endpoint) + Utils.debug("model", provider_conf.model) return { - url = Utils.url_join(base.endpoint, "/chat/completions"), - proxy = base.proxy, - insecure = base.allow_insecure, + url = Utils.url_join(provider_conf.endpoint, "/chat/completions"), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - model = base.model, + model = provider_conf.model, messages = M.parse_messages(prompt_opts), stream = stream, tools = tools, - }, body_opts), + }, request_body), } end diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua index f1ca9ee..5273483 100644 --- a/lua/avante/providers/vertex.lua +++ b/lua/avante/providers/vertex.lua @@ -32,22 +32,22 @@ M.parse_api_key = function() end M.parse_curl_args = function(provider, prompt_opts) - local base, body_opts = P.parse_config(provider) + local provider_conf, request_body = P.parse_config(provider) local location = vim.fn.getenv("LOCATION") or "default-location" local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" - local model_id = base.model or "default-model-id" - local url = base.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) + local model_id = provider_conf.model or "default-model-id" + local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id) - body_opts = vim.tbl_deep_extend("force", body_opts, { + request_body = vim.tbl_deep_extend("force", request_body, { generationConfig = { - temperature = body_opts.temperature, - maxOutputTokens = body_opts.max_tokens, + temperature = request_body.temperature, + maxOutputTokens = request_body.max_tokens, }, }) - body_opts.temperature = nil - body_opts.max_tokens = nil + request_body.temperature = nil + request_body.max_tokens = nil local bearer_token = M.parse_api_key() return { @@ -56,9 +56,9 @@ M.parse_curl_args = function(provider, prompt_opts) ["Authorization"] = "Bearer " .. bearer_token, ["Content-Type"] = "application/json; charset=utf-8", }, - proxy = base.proxy, - insecure = base.allow_insecure, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), body_opts), + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, + body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body), } end