|
|
|
|
@@ -1,6 +1,9 @@
|
|
|
|
|
local fn = vim.fn
|
|
|
|
|
local api = vim.api
|
|
|
|
|
|
|
|
|
|
local curl = require("plenary.curl")
|
|
|
|
|
local Input = require("nui.input")
|
|
|
|
|
local Event = require("nui.utils.autocmd").event
|
|
|
|
|
|
|
|
|
|
local Utils = require("avante.utils")
|
|
|
|
|
local Config = require("avante.config")
|
|
|
|
|
@@ -9,6 +12,153 @@ local Tiktoken = require("avante.tiktoken")
|
|
|
|
|
---@class avante.AiBot
|
|
|
|
|
local M = {}
|
|
|
|
|
|
|
|
|
|
---@class Environment: table<[string], any>
|
|
|
|
|
---@field [string] string the environment variable name
|
|
|
|
|
---@field fallback? string Optional fallback API key environment variable name
|
|
|
|
|
|
|
|
|
|
---@class EnvironmentHandler: table<[Provider], string>
|
|
|
|
|
local E = {
|
|
|
|
|
---@type table<Provider, Environment | string>
|
|
|
|
|
env = {
|
|
|
|
|
openai = "OPENAI_API_KEY",
|
|
|
|
|
claude = "ANTHROPIC_API_KEY",
|
|
|
|
|
azure = { "AZURE_OPENAI_API_KEY", fallback = "OPENAI_API_KEY" },
|
|
|
|
|
},
|
|
|
|
|
_once = false,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
E = setmetatable(E, {
|
|
|
|
|
---@param k Provider
|
|
|
|
|
__index = function(_, k)
|
|
|
|
|
local envvar = E.env[k]
|
|
|
|
|
if type(envvar) == "string" then
|
|
|
|
|
local value = os.getenv(envvar)
|
|
|
|
|
return value and true or false
|
|
|
|
|
elseif type(envvar) == "table" then
|
|
|
|
|
local main_key = envvar[1]
|
|
|
|
|
local value = os.getenv(main_key)
|
|
|
|
|
if value then
|
|
|
|
|
return true
|
|
|
|
|
elseif envvar.fallback then
|
|
|
|
|
return os.getenv(envvar.fallback) and true or false
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return false
|
|
|
|
|
end,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
-- courtesy of https://github.com/MunifTanjim/nui.nvim/wiki/nui.input
|
|
|
|
|
local SecretInput = Input:extend("SecretInput")
|
|
|
|
|
|
|
|
|
|
function SecretInput:init(popup_options, options)
|
|
|
|
|
assert(
|
|
|
|
|
not options.conceal_char or vim.api.nvim_strwidth(options.conceal_char) == 1,
|
|
|
|
|
"conceal_char must be a single char"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
popup_options.win_options = vim.tbl_deep_extend("force", popup_options.win_options or {}, {
|
|
|
|
|
conceallevel = 2,
|
|
|
|
|
concealcursor = "nvi",
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
SecretInput.super.init(self, popup_options, options)
|
|
|
|
|
|
|
|
|
|
self._.conceal_char = type(options.conceal_char) == "nil" and "*" or options.conceal_char
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function SecretInput:mount()
|
|
|
|
|
SecretInput.super.mount(self)
|
|
|
|
|
|
|
|
|
|
local conceal_char = self._.conceal_char
|
|
|
|
|
local prompt_length = vim.api.nvim_strwidth(vim.fn.prompt_getprompt(self.bufnr))
|
|
|
|
|
|
|
|
|
|
vim.api.nvim_buf_call(self.bufnr, function()
|
|
|
|
|
vim.cmd(string.format(
|
|
|
|
|
[[
|
|
|
|
|
syn region SecretValue start=/^/ms=s+%s end=/$/ contains=SecretChar
|
|
|
|
|
syn match SecretChar /./ contained conceal %s
|
|
|
|
|
]],
|
|
|
|
|
prompt_length,
|
|
|
|
|
conceal_char and "cchar=" .. (conceal_char or "*") or ""
|
|
|
|
|
))
|
|
|
|
|
end)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
--- return the environment variable name for the given provider
|
|
|
|
|
---@param provider? Provider
|
|
|
|
|
---@return string the envvar key
|
|
|
|
|
E.key = function(provider)
|
|
|
|
|
provider = provider or Config.provider
|
|
|
|
|
local var = E.env[provider]
|
|
|
|
|
return type(var) == "table" and var[1] ---@cast var string
|
|
|
|
|
or var
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
E.setup = function(var)
|
|
|
|
|
if E._once then
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
local input = SecretInput({
|
|
|
|
|
position = "50%",
|
|
|
|
|
size = {
|
|
|
|
|
width = 40,
|
|
|
|
|
},
|
|
|
|
|
border = {
|
|
|
|
|
style = "single",
|
|
|
|
|
text = {
|
|
|
|
|
top = "Enter " .. var,
|
|
|
|
|
top_align = "center",
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
win_options = {
|
|
|
|
|
winhighlight = "Normal:Normal,FloatBorder:Normal",
|
|
|
|
|
},
|
|
|
|
|
}, {
|
|
|
|
|
prompt = "> ",
|
|
|
|
|
default_value = "",
|
|
|
|
|
on_submit = function(value)
|
|
|
|
|
vim.fn.setenv(var, value)
|
|
|
|
|
end,
|
|
|
|
|
on_close = function()
|
|
|
|
|
if not E[Config.provider] then
|
|
|
|
|
vim.notify_once("Failed to set " .. var .. ". Avante won't work as expected", vim.log.levels.WARN)
|
|
|
|
|
end
|
|
|
|
|
end,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, {
|
|
|
|
|
pattern = "*",
|
|
|
|
|
callback = function()
|
|
|
|
|
if E._once then
|
|
|
|
|
return
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
vim.defer_fn(function()
|
|
|
|
|
-- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf
|
|
|
|
|
local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" }
|
|
|
|
|
local exclude_filetypes =
|
|
|
|
|
{ "NvimTree", "Outline", "help", "dashboard", "alpha", "qf", "ministarter", "TelescopePrompt", "gitcommit" }
|
|
|
|
|
if
|
|
|
|
|
not vim.tbl_contains(exclude_buftypes, vim.bo.buftype)
|
|
|
|
|
and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype)
|
|
|
|
|
then
|
|
|
|
|
E._once = true
|
|
|
|
|
input:mount()
|
|
|
|
|
end
|
|
|
|
|
end, 200)
|
|
|
|
|
end,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
input:map("n", "<Esc>", function()
|
|
|
|
|
input:unmount()
|
|
|
|
|
end, { noremap = true })
|
|
|
|
|
|
|
|
|
|
input:on(Event.BufLeave, function()
|
|
|
|
|
input:unmount()
|
|
|
|
|
end)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
local system_prompt = [[
|
|
|
|
|
You are an excellent programming expert.
|
|
|
|
|
]]
|
|
|
|
|
@@ -57,10 +207,7 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
|
|
|
|
]]
|
|
|
|
|
|
|
|
|
|
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
|
|
|
|
|
local api_key = os.getenv("ANTHROPIC_API_KEY")
|
|
|
|
|
if not api_key then
|
|
|
|
|
error("ANTHROPIC_API_KEY environment variable is not set")
|
|
|
|
|
end
|
|
|
|
|
local api_key = os.getenv(E.key("azure"))
|
|
|
|
|
|
|
|
|
|
local tokens = Config.claude.max_tokens
|
|
|
|
|
local headers = {
|
|
|
|
|
@@ -174,11 +321,7 @@ local function call_claude_api_stream(question, code_lang, code_content, selecte
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
|
|
|
|
|
local api_key = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
if not api_key and Config.provider == "openai" then
|
|
|
|
|
error("OPENAI_API_KEY environment variable is not set")
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
local api_key = os.getenv(E.key("openai"))
|
|
|
|
|
local user_prompt = base_user_prompt
|
|
|
|
|
.. "\n\nCODE:\n"
|
|
|
|
|
.. "```"
|
|
|
|
|
@@ -209,10 +352,7 @@ local function call_openai_api_stream(question, code_lang, code_content, selecte
|
|
|
|
|
|
|
|
|
|
local url, headers, body
|
|
|
|
|
if Config.provider == "azure" then
|
|
|
|
|
api_key = os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
|
|
|
|
|
if not api_key then
|
|
|
|
|
error("Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable.")
|
|
|
|
|
end
|
|
|
|
|
api_key = os.getenv(E.key("azure"))
|
|
|
|
|
url = Config.azure.endpoint
|
|
|
|
|
.. "/openai/deployments/"
|
|
|
|
|
.. Config.azure.deployment
|
|
|
|
|
@@ -306,4 +446,11 @@ function M.call_ai_api_stream(question, code_lang, code_content, selected_conten
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function M.setup()
|
|
|
|
|
local has = E[Config.provider]
|
|
|
|
|
if not has then
|
|
|
|
|
E.setup(E.key())
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
return M
|
|
|
|
|
|