From a3e5053d552032214b0f117180f87266b1f2bc4e Mon Sep 17 00:00:00 2001 From: yetone Date: Sat, 16 Nov 2024 02:09:14 +0800 Subject: [PATCH] fix: preset vendors missing many fields (#851) --- lua/avante/config.lua | 7 ++++--- lua/avante/providers/init.lua | 11 ++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 0b679fe..59bddee 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -8,7 +8,7 @@ local M = {} ---@class avante.Config M.defaults = { debug = false, - ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | [string] + ---@alias Provider "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | string provider = "claude", -- Only recommend using Claude auto_suggestions_provider = "claude", ---@alias Tokenizer "tiktoken" | "hf" @@ -88,7 +88,7 @@ M.defaults = { vendors = { ---@type AvanteSupportedProvider ["claude-haiku"] = { - endpoint = "https://api.anthropic.com", + __inherited_from = "claude", model = "claude-3-5-haiku-20241022", timeout = 30000, -- Timeout in milliseconds temperature = 0, @@ -97,7 +97,7 @@ M.defaults = { }, ---@type AvanteSupportedProvider ["claude-opus"] = { - endpoint = "https://api.anthropic.com", + __inherited_from = "claude", model = "claude-3-opus-20240229", timeout = 30000, -- Timeout in milliseconds temperature = 0, @@ -340,6 +340,7 @@ M.BASE_PROVIDER_KEYS = { "tokenizer_id", "use_xml_format", "role_map", + "__inherited_from", } return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 6582648..9b55e9d 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -50,6 +50,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil } ---@field _shellenv? string --- ---@class AvanteSupportedProvider: AvanteDefaultBaseProvider +---@field __inherited_from? string ---@field temperature? number ---@field max_tokens? number --- @@ -267,7 +268,15 @@ M = setmetatable(M, { ---@diagnostic disable: undefined-field,no-unknown,inject-field if Config.vendors[k] ~= nil then Opts.parse_response = Opts.parse_response_data - t[k] = Opts + if Opts.__inherited_from ~= nil then + local BaseOpts = M.get_config(Opts.__inherited_from) + local ok, module = pcall(require, "avante.providers." .. Opts.__inherited_from) + if not ok then error("Failed to load provider: " .. Opts.__inherited_from) end + Opts._shellenv = module.api_key_name ~= M.AVANTE_INTERNAL_KEY and module.api_key_name or nil + t[k] = vim.tbl_deep_extend("keep", BaseOpts, Opts, module) + else + t[k] = Opts + end else local ok, module = pcall(require, "avante.providers." .. k) if not ok then error("Failed to load provider: " .. k) end