feat(llm): add support for parsing secret vault (#200)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -52,35 +52,24 @@ local Dressing = require("avante.ui.dressing")
|
||||
---@field temperature? number
|
||||
---@field max_tokens? number
|
||||
---
|
||||
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
|
||||
---@field deployment string
|
||||
---@field api_version string
|
||||
---@field temperature number
|
||||
---@field max_tokens number
|
||||
---
|
||||
---@class AvanteCopilotProvider: AvanteSupportedProvider
|
||||
---@field timeout number
|
||||
---
|
||||
---@class AvanteGeminiProvider: AvanteDefaultBaseProvider
|
||||
---@field model string
|
||||
---
|
||||
---@class AvanteProvider: AvanteDefaultBaseProvider
|
||||
---@field parse_response_data AvanteResponseParser
|
||||
---@field parse_curl_args AvanteCurlArgsParser
|
||||
---@field parse_stream_data? AvanteStreamParser
|
||||
---
|
||||
---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil
|
||||
---@alias AvanteChunkParser fun(chunk: string): any
|
||||
---@alias AvanteCompleteParser fun(err: string|nil): nil
|
||||
---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table<string, any>
|
||||
---
|
||||
---@class AvanteProvider: AvanteSupportedProvider
|
||||
---@field parse_response_data AvanteResponseParser
|
||||
---@field parse_curl_args? AvanteCurlArgsParser
|
||||
---@field parse_stream_data? AvanteStreamParser
|
||||
---
|
||||
---@class AvanteProviderFunctor
|
||||
---@field parse_message AvanteMessageParser
|
||||
---@field parse_response AvanteResponseParser
|
||||
---@field parse_curl_args AvanteCurlArgsParser
|
||||
---@field setup? fun(): nil
|
||||
---@field setup fun(): nil
|
||||
---@field has fun(): boolean
|
||||
---@field api_key_name string
|
||||
---@field parse_api_key fun(): string | nil
|
||||
---@field parse_stream_data? AvanteStreamParser
|
||||
---
|
||||
---@class avante.Providers
|
||||
@@ -92,38 +81,50 @@ local Dressing = require("avante.ui.dressing")
|
||||
---@field cohere AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
setmetatable(M, {
|
||||
---@param t avante.Providers
|
||||
---@param k Provider
|
||||
__index = function(t, k)
|
||||
if Config.vendors[k] ~= nil then
|
||||
---@type AvanteProvider
|
||||
local v = Config.vendors[k]
|
||||
|
||||
-- Patch from vendors similar to supported providers.
|
||||
---@type AvanteProviderFunctor
|
||||
t[k] = setmetatable({}, { __index = v })
|
||||
-- Hack for aliasing and makes it sane for us.
|
||||
t[k].parse_response = v.parse_response_data
|
||||
t[k].has = function()
|
||||
return os.getenv(t[k].api_key_name) and true or false
|
||||
end
|
||||
|
||||
return t[k]
|
||||
end
|
||||
|
||||
---@type AvanteProviderFunctor
|
||||
t[k] = require("avante.providers." .. k)
|
||||
return t[k]
|
||||
end,
|
||||
})
|
||||
|
||||
---@class EnvironmentHandler
|
||||
local E = {}
|
||||
|
||||
---@private
|
||||
E._once = false
|
||||
|
||||
---@private
|
||||
---@type table<string, string>
|
||||
E.cache = {}
|
||||
|
||||
---@param Opts AvanteSupportedProvider | AvanteProviderFunctor
|
||||
---@return string | nil
|
||||
E.parse_envvar = function(Opts)
|
||||
local api_key_name = Opts.api_key_name
|
||||
if api_key_name == nil then
|
||||
error("Requires api_key_name")
|
||||
end
|
||||
|
||||
if E.cache[api_key_name] ~= nil then
|
||||
return E.cache[api_key_name]
|
||||
end
|
||||
|
||||
local cmd = api_key_name:match("^cmd:(.*)")
|
||||
|
||||
local key = nil
|
||||
if cmd ~= nil then
|
||||
local ok, job = pcall(vim.system, vim.split(cmd, " ", { trimempty = true }), { text = true })
|
||||
if not ok then
|
||||
Utils.error("Failed to execute command to retrieve secrets: " .. cmd, { once = true, title = "Avante" })
|
||||
else
|
||||
local out = job:wait()
|
||||
key = out.stdout
|
||||
end
|
||||
else
|
||||
key = os.getenv(api_key_name)
|
||||
end
|
||||
|
||||
if key ~= nil then
|
||||
E.cache[api_key_name] = key
|
||||
end
|
||||
|
||||
return key
|
||||
end
|
||||
|
||||
--- initialize the environment variable for current neovim session.
|
||||
--- This will only run once and spawn a UI for users to input the envvar.
|
||||
---@param opts {refresh: boolean, provider: AvanteProviderFunctor}
|
||||
@@ -131,7 +132,8 @@ E._once = false
|
||||
E.setup = function(opts)
|
||||
local var = opts.provider.api_key_name
|
||||
|
||||
if var == M.AVANTE_INTERNAL_KEY then
|
||||
-- check if var is a all caps string
|
||||
if var == M.AVANTE_INTERNAL_KEY or var:match("^cmd:(.*)") then
|
||||
return
|
||||
end
|
||||
|
||||
@@ -149,51 +151,54 @@ E.setup = function(opts)
|
||||
end
|
||||
end
|
||||
|
||||
if refresh then
|
||||
local function mount_dressing_buffer()
|
||||
vim.defer_fn(function()
|
||||
Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm })
|
||||
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
|
||||
local exclude_buftypes = { "qf", "nofile" }
|
||||
local exclude_filetypes = {
|
||||
"NvimTree",
|
||||
"Outline",
|
||||
"help",
|
||||
"dashboard",
|
||||
"alpha",
|
||||
"qf",
|
||||
"ministarter",
|
||||
"TelescopePrompt",
|
||||
"gitcommit",
|
||||
"gitrebase",
|
||||
"DressingInput",
|
||||
}
|
||||
if
|
||||
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
|
||||
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
|
||||
and not opts.provider.has()
|
||||
then
|
||||
Dressing.initialize_input_buffer({
|
||||
opts = { prompt = "Enter " .. var .. ": " },
|
||||
on_confirm = on_confirm,
|
||||
})
|
||||
end
|
||||
end, 200)
|
||||
elseif not E._once then
|
||||
end
|
||||
|
||||
if refresh then
|
||||
mount_dressing_buffer()
|
||||
return
|
||||
end
|
||||
|
||||
if not E._once then
|
||||
E._once = true
|
||||
api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, {
|
||||
pattern = "*",
|
||||
once = true,
|
||||
callback = function()
|
||||
vim.defer_fn(function()
|
||||
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
|
||||
local exclude_buftypes = { "qf", "nofile" }
|
||||
local exclude_filetypes = {
|
||||
"NvimTree",
|
||||
"Outline",
|
||||
"help",
|
||||
"dashboard",
|
||||
"alpha",
|
||||
"qf",
|
||||
"ministarter",
|
||||
"TelescopePrompt",
|
||||
"gitcommit",
|
||||
"gitrebase",
|
||||
"DressingInput",
|
||||
}
|
||||
if
|
||||
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
|
||||
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
|
||||
and not opts.provider.has()
|
||||
then
|
||||
Dressing.initialize_input_buffer({
|
||||
opts = { prompt = "Enter " .. var .. ": " },
|
||||
on_confirm = on_confirm,
|
||||
})
|
||||
end
|
||||
end, 200)
|
||||
end,
|
||||
callback = mount_dressing_buffer,
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
---@param provider Provider
|
||||
E.is_local = function(provider)
|
||||
local cur = M.get(provider)
|
||||
local cur = M.get_config(provider)
|
||||
return cur["local"] ~= nil and cur["local"] or false
|
||||
end
|
||||
|
||||
@@ -201,14 +206,47 @@ M.env = E
|
||||
|
||||
M.AVANTE_INTERNAL_KEY = "__avante_env_internal"
|
||||
|
||||
M = setmetatable(M, {
|
||||
---@param t avante.Providers
|
||||
---@param k Provider
|
||||
__index = function(t, k)
|
||||
---@type AvanteProviderFunctor
|
||||
local Opts = M.get_config(k)
|
||||
|
||||
if Config.vendors[k] ~= nil then
|
||||
Opts.parse_response = Opts.parse_response_data
|
||||
t[k] = Opts
|
||||
else
|
||||
t[k] = vim.tbl_deep_extend("keep", Opts, require("avante.providers." .. k))
|
||||
end
|
||||
|
||||
t[k].parse_api_key = function()
|
||||
return E.parse_envvar(t[k])
|
||||
end
|
||||
|
||||
if t[k].has == nil then
|
||||
t[k].has = function()
|
||||
return E.parse_envvar(t[k]) ~= nil
|
||||
end
|
||||
end
|
||||
|
||||
if t[k].setup == nil then
|
||||
t[k].setup = function()
|
||||
t[k].parse_api_key()
|
||||
end
|
||||
end
|
||||
|
||||
return t[k]
|
||||
end,
|
||||
})
|
||||
|
||||
M.setup = function()
|
||||
---@type AvanteProviderFunctor
|
||||
local provider = M[Config.provider]
|
||||
E.setup({ provider = provider })
|
||||
|
||||
if provider.setup ~= nil then
|
||||
vim.schedule(function()
|
||||
provider.setup()
|
||||
end
|
||||
end)
|
||||
|
||||
M.commands()
|
||||
end
|
||||
@@ -216,6 +254,8 @@ end
|
||||
---@private
|
||||
---@param provider Provider
|
||||
function M.refresh(provider)
|
||||
require("avante.config").override({ provider = provider })
|
||||
|
||||
---@type AvanteProviderFunctor
|
||||
local p = M[Config.provider]
|
||||
if not p.has() then
|
||||
@@ -223,7 +263,6 @@ function M.refresh(provider)
|
||||
else
|
||||
Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" })
|
||||
end
|
||||
require("avante.config").override({ provider = provider })
|
||||
end
|
||||
|
||||
local default_providers = { "openai", "claude", "azure", "gemini", "copilot" }
|
||||
@@ -242,7 +281,8 @@ M.commands = function()
|
||||
end
|
||||
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
|
||||
-- join two tables
|
||||
local Keys = vim.list_extend(default_providers, vim.tbl_keys(Config.vendors or {}))
|
||||
local Keys = vim.list_extend({}, default_providers)
|
||||
Keys = vim.list_extend(Keys, vim.tbl_keys(Config.vendors or {}))
|
||||
return vim.tbl_filter(function(key)
|
||||
return key:find(prefix) == 1
|
||||
end, Keys)
|
||||
@@ -280,7 +320,8 @@ end
|
||||
|
||||
---@private
|
||||
---@param provider Provider
|
||||
M.get = function(provider)
|
||||
---@return AvanteProviderFunctor
|
||||
M.get_config = function(provider)
|
||||
local cur = Config.get_provider(provider or Config.provider)
|
||||
return type(cur) == "function" and cur() or cur
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user