diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index f7a04fe..bc8f3b8 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -1,5 +1,5 @@ local Utils = require("avante.utils") -local P = require("avante.providers") +local Providers = require("avante.providers") ---@class AvanteBedrockProviderFunctor local M = {} @@ -42,7 +42,7 @@ function M.setup() end function M.load_model_handler() - local provider_conf, _ = P.parse_config(P["bedrock"]) + local provider_conf, _ = Providers.parse_config(Providers["bedrock"]) local bedrock_model = provider_conf.model if provider_conf.model:match("anthropic") then bedrock_model = "claude" end @@ -99,7 +99,7 @@ end ---@param prompt_opts AvantePromptOptions ---@return table function M:parse_curl_args(prompt_opts) - local provider_conf, request_body = P.parse_config(self) + local provider_conf, request_body = Providers.parse_config(self) local access_key_id, secret_access_key, session_token, region diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index d27ec3d..23fbd42 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -1,4 +1,4 @@ -local api, fn = vim.api, vim.fn +local api = vim.api local Config = require("avante.config") local Utils = require("avante.utils") @@ -136,28 +136,34 @@ M = setmetatable(M, { __index = function(t, k) if Config.providers[k] == nil then error("Failed to find provider: " .. k, 2) end - local provider_config = M.get_config(k) + t[k] = setmetatable({}, { + __index = function(_, k_) + local provider_config = M.get_config(k) - if provider_config.__inherited_from ~= nil then - local base_provider_config = M.get_config(provider_config.__inherited_from) - local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from) - if not ok then error("Failed to load provider: " .. provider_config.__inherited_from, 2) end - provider_config = Utils.deep_extend_with_metatable("force", module, base_provider_config, provider_config) - else - local ok, module = pcall(require, "avante.providers." .. k) - if ok then - provider_config = Utils.deep_extend_with_metatable("force", module, provider_config) - elseif provider_config.parse_curl_args == nil then - error( - string.format( - 'The configuration of your provider "%s" is incorrect, missing the `__inherited_from` attribute or a custom `parse_curl_args` function. Please fix your provider configuration. For more details, see: https://github.com/yetone/avante.nvim/wiki/Custom-providers', - k - ) - ) - end - end + if provider_config.__inherited_from ~= nil then + local base_provider_config = M.get_config(provider_config.__inherited_from) + local ok, module = pcall(require, "avante.providers." .. provider_config.__inherited_from) + if not ok then error("Failed to load provider: " .. provider_config.__inherited_from, 2) end + -- provider_config = Utils.deep_extend_with_metatable("force", module, base_provider_config, provider_config) + provider_config = Utils.inherit({}, provider_config, base_provider_config, module) + else + local ok, module = pcall(require, "avante.providers." .. k) + if ok then + -- provider_config = Utils.deep_extend_with_metatable("force", module, provider_config) + provider_config = Utils.inherit({}, provider_config, module) + elseif provider_config.parse_curl_args == nil then + error( + string.format( + 'The configuration of your provider "%s" is incorrect, missing the `__inherited_from` attribute or a custom `parse_curl_args` function. Please fix your provider configuration. For more details, see: https://github.com/yetone/avante.nvim/wiki/Custom-providers', + k_ + ) + ) + end + end - t[k] = provider_config + return provider_config[k_] + end, + }) if t[k].parse_api_key == nil then t[k].parse_api_key = function() return E.parse_envvar(t[k]) end end @@ -221,17 +227,10 @@ end ---@return AvanteDefaultBaseProvider provider_opts ---@return table request_body function M.parse_config(opts) - ---@type AvanteDefaultBaseProvider - local provider_opts = {} - - for key, value in pairs(opts) do - if key ~= "extra_request_body" then provider_opts[key] = value end - end - ---@type table local request_body = opts.extra_request_body or {} - return provider_opts, request_body + return Utils.inherit({}, opts), request_body end ---@param provider_name avante.ProviderName diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index e0ae1ed..a1a8a59 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1133,6 +1133,17 @@ function M.icon(string_with_icon, utf8_fallback) end end +function M.inherit(base, ...) + local children = { ... } + return setmetatable(base, { + __index = function(_, k) + for _, child in ipairs(children) do + if child[k] ~= nil then return child[k] end + end + end, + }) +end + function M.deep_extend_with_metatable(behavior, ...) local tables = { ... } local base = tables[1]