diff --git a/lua/avante/config.lua b/lua/avante/config.lua index a1f8c26..08b2f69 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -20,7 +20,7 @@ local M = {} ---@field custom_tools AvanteLLMToolPublic[] M._defaults = { debug = false, - ---@alias ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | string + ---@alias ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | string provider = "claude", -- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive, -- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048 @@ -225,7 +225,7 @@ M._defaults = { }, ---@type AvanteSupportedProvider bedrock = { - model = "anthropic.claude-3-5-sonnet-20240620-v1:0", + model = "anthropic.claude-3-5-sonnet-20241022-v2:0", timeout = 30000, -- Timeout in milliseconds temperature = 0, max_tokens = 8000, @@ -474,7 +474,7 @@ function M.setup(opts) M._options = merged M.provider_names = vim .iter(M._defaults) - :filter(function(_, value) return type(value) == "table" and value.endpoint ~= nil end) + :filter(function(_, value) return type(value) == "table" and (value.endpoint ~= nil or value.model ~= nil) end) :fold({}, function(acc, k) acc = vim.list_extend({}, acc) acc = vim.list_extend(acc, { k }) @@ -519,12 +519,12 @@ function M.get_window_width() return math.ceil(vim.o.columns * (M.windows.width ---@param provider_name ProviderName ---@return boolean -function M.has_provider(provider_name) return M._options[provider_name] ~= nil or M.vendors[provider_name] ~= nil end +function M.has_provider(provider_name) return vim.list_contains(M.provider_names, provider_name) end ---get supported providers ---@param provider_name ProviderName ----@return AvanteProviderFunctor -function M.get_provider(provider_name) +function M.get_provider_config(provider_name) + if not M.has_provider(provider_name) then error("No provider found: " .. provider_name, 2) end if M._options[provider_name] ~= nil then return vim.deepcopy(M._options[provider_name], true) elseif M.vendors and M.vendors[provider_name] ~= nil then diff --git a/lua/avante/model_selector.lua b/lua/avante/model_selector.lua index dad319d..2b1616f 100644 --- a/lua/avante/model_selector.lua +++ b/lua/avante/model_selector.lua @@ -21,7 +21,7 @@ function M.open() -- Collect models from main providers and vendors for _, provider_name in ipairs(Config.provider_names) do - local entry = create_model_entry(provider_name, Config.get_provider(provider_name)) + local entry = create_model_entry(provider_name, Config.get_provider_config(provider_name)) if entry then table.insert(models, entry) end end @@ -43,7 +43,7 @@ function M.open() Config.override({ [choice.provider_name] = vim.tbl_deep_extend( "force", - Config.get_provider(choice.provider_name), + Config.get_provider_config(choice.provider_name), { model = choice.model } ), }) diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index 9bc1445..1712335 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -35,9 +35,9 @@ function M:parse_response(ctx, data_stream, event_state, opts) return model_handler.parse_response(self, ctx, data_stream, event_state, opts) end -function M:build_bedrock_payload(prompt_opts, body_opts) +function M:build_bedrock_payload(prompt_opts, request_body) local model_handler = M.load_model_handler() - return model_handler.build_bedrock_payload(self, prompt_opts, body_opts) + return model_handler.build_bedrock_payload(self, prompt_opts, request_body) end function M:parse_stream_data(ctx, data, opts) @@ -66,7 +66,7 @@ end ---@param prompt_opts AvantePromptOptions ---@return table function M:parse_curl_args(prompt_opts) - local base, body_opts = P.parse_config(self) + local provider_conf, request_body = P.parse_config(self) local api_key = self.parse_api_key() if api_key == nil then error("Cannot get the bedrock api key!") end @@ -79,7 +79,7 @@ function M:parse_curl_args(prompt_opts) local endpoint = string.format( "https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", aws_region, - base.model + provider_conf.model ) local headers = { @@ -88,7 +88,7 @@ function M:parse_curl_args(prompt_opts) if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end - local body_payload = self:build_bedrock_payload(prompt_opts, body_opts) + local body_payload = self:build_bedrock_payload(prompt_opts, request_body) local rawArgs = { "--aws-sigv4", @@ -99,8 +99,8 @@ function M:parse_curl_args(prompt_opts) return { url = endpoint, - proxy = base.proxy, - insecure = base.allow_insecure, + proxy = provider_conf.proxy, + insecure = provider_conf.allow_insecure, headers = headers, body = body_payload, rawArgs = rawArgs, diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua index a714821..e5a16a5 100644 --- a/lua/avante/providers/bedrock/claude.lua +++ b/lua/avante/providers/bedrock/claude.lua @@ -23,19 +23,19 @@ M.parse_response = Claude.parse_response ---@param provider AvanteProviderFunctor ---@param prompt_opts AvantePromptOptions ----@param body_opts table +---@param request_body table ---@return table -function M.build_bedrock_payload(provider, prompt_opts, body_opts) +function M.build_bedrock_payload(provider, prompt_opts, request_body) local system_prompt = prompt_opts.system_prompt or "" local messages = provider:parse_messages(prompt_opts) - local max_tokens = body_opts.max_tokens or 2000 + local max_tokens = request_body.max_tokens or 2000 local payload = { anthropic_version = "bedrock-2023-05-31", max_tokens = max_tokens, messages = messages, system = system_prompt, } - return vim.tbl_deep_extend("force", payload, body_opts or {}) + return vim.tbl_deep_extend("force", payload, request_body or {}) end return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 21d6494..d941d56 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -198,26 +198,25 @@ M = setmetatable(M, { ---@param t avante.Providers ---@param k ProviderName __index = function(t, k) - ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor - local Opts = M.get_config(k) + local provider_config = M.get_config(k) ---@diagnostic disable: undefined-field,no-unknown,inject-field if Config.vendors[k] ~= nil then - if Opts.parse_response_data ~= nil then + if provider_config.parse_response_data ~= nil then Utils.error("parse_response_data is not supported for avante.nvim vendors") end - if Opts.__inherited_from ~= nil then - local BaseOpts = M.get_config(Opts.__inherited_from) - local ok, module = pcall(require, "avante.providers." .. Opts.__inherited_from) - if not ok then error("Failed to load provider: " .. Opts.__inherited_from) end - t[k] = vim.tbl_deep_extend("keep", Opts, BaseOpts, module) + if provider_config.__inherited_from ~= nil then + local base_provider_config = M.get_config(provider_config.__inherited_from) + local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from) + if not ok then error("Failed to load provider: " .. provider_config.__inherited_from) end + t[k] = Utils.deep_extend_with_metatable("keep", provider_config, base_provider_config, module) else - t[k] = Opts + t[k] = provider_config end else local ok, module = pcall(require, "avante.providers." .. k) if not ok then error("Failed to load provider: " .. k) end - t[k] = vim.tbl_deep_extend("keep", Opts, module) + t[k] = Utils.deep_extend_with_metatable("keep", provider_config, module) end t[k].parse_api_key = function() return E.parse_envvar(t[k]) end @@ -310,10 +309,9 @@ end ---@private ---@param provider_name ProviderName ----@return AvanteProviderFunctor | AvanteBedrockProviderFunctor function M.get_config(provider_name) provider_name = provider_name or Config.provider - local cur = Config.get_provider(provider_name) + local cur = Config.get_provider_config(provider_name) return type(cur) == "function" and cur() or cur end diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 3b99756..c8ac9ec 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2177,7 +2177,7 @@ function Sidebar:reset_memory(args, cb) table.insert(chat_history, { timestamp = get_timestamp(), provider = Config.provider, - model = Config.get_provider(Config.provider).model, + model = Config.get_provider_config(Config.provider).model, request = "", response = "", original_response = "", @@ -2414,7 +2414,8 @@ function Sidebar:create_input_container(opts) ---@param request string local function handle_submit(request) - local model = Config.has_provider(Config.provider) and Config.get_provider(Config.provider).model or "default" + local model = Config.has_provider(Config.provider) and Config.get_provider_config(Config.provider).model + or "default" local timestamp = get_timestamp() diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 97724d8..90d400a 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -272,7 +272,7 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_stream_data? AvanteStreamParser ---@field on_error? fun(result: table): nil --- ----@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, body_opts: table): table +---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table): table --- ---@class AvanteBedrockProviderFunctor: AvanteProviderFunctor ---@field load_model_handler fun(): AvanteBedrockModelHandler diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 93b34b4..0d93c01 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1015,4 +1015,17 @@ function M.icon(string_with_icon, utf8_fallback) end end +function M.deep_extend_with_metatable(behavior, ...) + local tables = { ... } + local base = tables[1] + if behavior == "keep" then base = tables[#tables] end + local mt = getmetatable(base) + + local result = vim.tbl_deep_extend(behavior, ...) + + if mt then setmetatable(result, mt) end + + return result +end + return M