diff --git a/README.md b/README.md index 026d49b..3096545 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,99 @@ lua_ls = { Then you can set `dev = true` in your `lazy` config for development. +## Custom Providers + +To add support for custom providers, one add `AvanteProvider` spec into `opts.vendors`: + +```lua +{ + provider = "my-custom-provider", -- You can then change this provider here + vendors = { + ["my-custom-provider"] = {...} + }, + windows = { + wrap_line = true, + width = 30, -- default % based on available width + }, + --- @class AvanteConflictUserConfig + diff = { + debug = false, + autojump = true, + ---@type string | fun(): any + list_opener = "copen", + }, +} + +``` + +A custom provider should following the following spec: + +```lua +---@type AvanteProvider +{ + endpoint = "https://api.openai.com/v1/chat/completions", -- The full endpoint of the provider + model = "gpt-4o", -- The model name to use with this provider + api_key_name = "OPENAI_API_KEY", -- The name of the environment variable that contains the API key + --- This function below will be used to parse in cURL arguments. + --- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer. + --- This code_opts include: + --- - question: Input from the users + --- - code_lang: the language of given code buffer + --- - code_content: content of code buffer + --- - selected_code_content: (optional) If given code content is selected in visual mode as context. + ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput + parse_curl_args = function(opts, code_opts) end + --- This function will be used to parse incoming SSE stream + --- It takes in the data stream as the first argument, followed by opts retrieved from given buffer. + --- This opts include: + --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk + --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk + --- - event_state: SSE event state. + ---@type fun(data_stream: string, opts: ResponseParser): nil + parse_response_data = function(data_stream, opts) end +} +``` + +
+Full working example of perplexity + +```lua +vendors = { + ---@type AvanteProvider + perplexity = { + endpoint = "https://api.perplexity.ai/chat/completions", + model = "llama-3.1-sonar-large-128k-online", + api_key_name = "PPLX_API_KEY", + --- this function below will be used to parse in cURL arguments. + parse_curl_args = function(opts, code_opts) + local Llm = require "avante.llm" + return { + url = opts.endpoint, + headers = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", + ["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name), + }, + body = { + model = opts.model, + messages = Llm.make_openai_message(code_opts), -- you can make your own message, but this is very advanced + temperature = 0, + max_tokens = 8192, + stream = true, -- this will be set by default. + }, + } + end, + -- The below function is used if the vendors has specific SSE spec that is not claude or openai. + parse_response_data = function(data_stream, opts) + local Llm = require "avante.llm" + Llm.parse_openai_response(data_stream, opts) + end, + }, +}, +``` + +
+ ## License avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file. diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 5e83c7a..43a3e9f 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -6,7 +6,7 @@ local M = {} ---@class avante.Config M.defaults = { - ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" + ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string] provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq" openai = { endpoint = "https://api.openai.com", @@ -39,6 +39,10 @@ M.defaults = { temperature = 0, max_tokens = 4096, }, + --- To add support for custom provider, follow the format below + --- See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details + ---@type table + vendors = {}, behaviour = { auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response. }, @@ -100,6 +104,11 @@ function M.setup(opts) ) end +---@param opts? avante.Config +function M.override(opts) + M.options = vim.tbl_deep_extend("force", M.options, opts or {}) +end + M = setmetatable(M, { __index = function(_, k) if M.options[k] then diff --git a/lua/avante/init.lua b/lua/avante/init.lua index f1251ed..aeada09 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -201,7 +201,7 @@ function M.setup(opts) end require("avante.diff").setup() - require("avante.ai_bot").setup() + require("avante.llm").setup() -- setup helpers H.autocmds() diff --git a/lua/avante/ai_bot.lua b/lua/avante/llm.lua similarity index 83% rename from lua/avante/ai_bot.lua rename to lua/avante/llm.lua index 400baf7..78adcab 100644 --- a/lua/avante/ai_bot.lua +++ b/lua/avante/llm.lua @@ -8,13 +8,13 @@ local Tiktoken = require("avante.tiktoken") local Dressing = require("avante.ui.dressing") ---@private ----@class AvanteAiBotInternal +---@class AvanteLLMInternal local H = {} ----@class avante.AiBot +---@class avante.LLM local M = {} -M.CANCEL_PATTERN = "AvanteAiBotEscape" +M.CANCEL_PATTERN = "AvanteLLMEscape" ---@class EnvironmentHandler: table<[Provider], string> local E = { @@ -31,16 +31,41 @@ local E = { E = setmetatable(E, { ---@param k Provider __index = function(_, k) - return os.getenv(E.env[k]) and true or false + local builtins = E.env[k] + if builtins then + return os.getenv(builtins) and true or false + end + + local external = Config.vendors[k] + if external then + return os.getenv(external.api_key_name) and true or false + end end, }) + +---@private E._once = false +E.is_default = function(provider) + return E.env[provider] and true or false +end + --- return the environment variable name for the given provider ---@param provider? Provider ---@return string the envvar key E.key = function(provider) - return E.env[provider or Config.provider] + provider = provider or Config.provider + + if E.is_default(provider) then + return E.env[provider] + end + + local external = Config.vendors[provider] + if external then + return external.api_key_name + else + error("Failed to find provider: " .. provider, 2) + end end ---@param provider? Provider @@ -52,6 +77,7 @@ end --- This will only run once and spawn a UI for users to input the envvar. ---@param var Provider supported providers ---@param refresh? boolean +---@private E.setup = function(var, refresh) refresh = refresh or false @@ -160,7 +186,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m ---@field code_content string ---@field selected_code_content? string --- ----@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): {role: "user" | "system", content: string | table}[] +---@class AvanteBaseMessage +---@field role "user" | "system" +---@field content string +--- +---@class AvanteClaudeMessage: AvanteBaseMessage +---@field role "user" +---@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[] +--- +---@alias AvanteOpenAIMessage AvanteBaseMessage +--- +---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage +--- +---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[] --- ---@class AvanteCurlOutput: {url: string, body: table | string, headers: table} ---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput @@ -169,12 +207,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m ---@field event_state string ---@field on_chunk fun(chunk: string): any ---@field on_complete fun(err: string|nil): any ----@field on_error? fun(err_type: string): nil ----@alias AvanteAiResponseParser fun(data_stream: string, opts: ResponseParser): nil +---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil +--- +---@class AvanteProvider +---@field endpoint string +---@field model string +---@field api_key_name string +---@field parse_response_data AvanteResponseParser +---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput ------------------------------Anthropic------------------------------ ----@type AvanteAiMessageBuilder +---@param opts AvantePromptOptions +---@return AvanteClaudeMessage[] H.make_claude_message = function(opts) local code_prompt_obj = { type = "text", @@ -232,7 +277,7 @@ H.make_claude_message = function(opts) } end ----@type AvanteAiResponseParser +---@type AvanteResponseParser H.parse_claude_response = function(data_stream, opts) if opts.event_state == "content_block_delta" then local json = vim.json.decode(data_stream) @@ -268,7 +313,8 @@ end ------------------------------OpenAI------------------------------ ----@type AvanteAiMessageBuilder +---@param opts AvantePromptOptions +---@return AvanteOpenAIMessage[] H.make_openai_message = function(opts) local user_prompt = base_user_prompt .. "\n\nCODE:\n" @@ -304,7 +350,7 @@ H.make_openai_message = function(opts) } end ----@type AvanteAiResponseParser +---@type AvanteResponseParser H.parse_openai_response = function(data_stream, opts) if data_stream:match('"%[DONE%]":') then opts.on_complete(nil) @@ -346,7 +392,7 @@ end ---@type AvanteAiMessageBuilder H.make_azure_message = H.make_openai_message ----@type AvanteAiResponseParser +---@type AvanteResponseParser H.parse_azure_response = H.parse_openai_response ---@type AvanteCurlArgsBuilder @@ -375,7 +421,7 @@ end ---@type AvanteAiMessageBuilder H.make_deepseek_message = H.make_openai_message ----@type AvanteAiResponseParser +---@type AvanteResponseParser H.parse_deepseek_response = H.parse_openai_response ---@type AvanteCurlArgsBuilder @@ -401,7 +447,7 @@ end ---@type AvanteAiMessageBuilder H.make_groq_message = H.make_openai_message ----@type AvanteAiResponseParser +---@type AvanteResponseParser H.parse_groq_response = H.parse_openai_response ---@type AvanteCurlArgsBuilder @@ -424,7 +470,7 @@ end ------------------------------Logic------------------------------ -local group = vim.api.nvim_create_augroup("AvanteAiBot", { clear = true }) +local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true }) local active_job = nil ---@param question string @@ -433,17 +479,35 @@ local active_job = nil ---@param selected_content_content string | nil ---@param on_chunk fun(chunk: string): any ---@param on_complete fun(err: string|nil): any -M.invoke_llm_stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) +M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) local provider = Config.provider local event_state = nil - ---@type AvanteCurlOutput - local spec = H["make_" .. provider .. "_curl_args"]({ + local code_opts = { question = question, code_lang = code_lang, code_content = code_content, selected_code_content = selected_content_content, - }) + } + local handler_opts = vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true) + + ---@type AvanteCurlOutput + local spec = nil + + ---@type AvanteProvider + local ProviderConfig = nil + + if E.is_default(provider) then + spec = H["make_" .. provider .. "_curl_args"](code_opts) + else + ProviderConfig = Config.vendors[provider] + spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts) + end + if spec.body.stream == nil then + spec = vim.tbl_deep_extend("force", spec, { + body = { stream = true }, + }) + end ---@param line string local function parse_and_call(line) @@ -454,10 +518,11 @@ M.invoke_llm_stream = function(question, code_lang, code_content, selected_conte end local data_match = line:match("^data: (.+)$") if data_match then - H["parse_" .. provider .. "_response"]( - data_match, - vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true) - ) + if ProviderConfig ~= nil then + ProviderConfig.parse_response_data(data_match, handler_opts) + else + H["parse_" .. provider .. "_response"](data_match, handler_opts) + end end end @@ -521,7 +586,7 @@ function M.refresh(provider) else vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO) end - require("avante").setup({ provider = provider }) + require("avante.config").override({ provider = provider }) end M.commands = function() @@ -536,11 +601,25 @@ M.commands = function() return {} end local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or "" + -- join two tables + local Keys = vim.list_extend(vim.tbl_keys(E.env), vim.tbl_keys(Config.vendors)) return vim.tbl_filter(function(key) return key:find(prefix) == 1 - end, vim.tbl_keys(E.env)) + end, Keys) end, }) end -return M +return setmetatable(M, { + __index = function(t, k) + local h = H[k] + if h then + return H[k] + end + local v = t[k] + if v then + return t[k] + end + error("Failed to find key: " .. k) + end, +}) diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 38e530e..c5635ed 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -7,7 +7,7 @@ local N = require("nui-components") local Config = require("avante.config") local View = require("avante.view") local Diff = require("avante.diff") -local AiBot = require("avante.ai_bot") +local Llm = require("avante.llm") local Utils = require("avante.utils") local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" @@ -141,7 +141,7 @@ function Sidebar:intialize() mode = { "n" }, key = "q", handler = function() - api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN }) + api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN }) self.renderer:close() end, }, @@ -149,7 +149,7 @@ function Sidebar:intialize() mode = { "n" }, key = "", handler = function() - api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN }) + api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN }) self.renderer:close() end, }, @@ -245,6 +245,9 @@ end ---@param content string concatenated content of the buffer ---@param opts? {focus?: boolean, stream?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view function Sidebar:update_content(content, opts) + if not self.view.buf then + return + end opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {}) if opts.stream then vim.schedule(function() @@ -643,9 +646,16 @@ function Sidebar:render() signal.is_loading = true local state = signal:get_value() local request = state.text + ---@type string + local model - local provider_config = Config[Config.provider] - local model = provider_config and provider_config.model or "default" + local builtins_provider_config = Config[Config.provider] + if builtins_provider_config ~= nil then + model = builtins_provider_config.model + else + local vendor_provider_config = Config.vendors[Config.provider] + model = vendor_provider_config and vendor_provider_config.model or "default" + end local timestamp = get_timestamp() @@ -670,50 +680,43 @@ function Sidebar:render() local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf }) - AiBot.invoke_llm_stream( - request, - filetype, - content_with_line_numbers, - selected_code_content_with_line_numbers, - function(chunk) - signal.is_loading = true - full_response = full_response .. chunk - self:update_content(chunk, { stream = true, scroll = false }) - vim.schedule(function() - vim.cmd("redraw") - end) - end, - function(err) - signal.is_loading = false + Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk) + signal.is_loading = true + full_response = full_response .. chunk + self:update_content(chunk, { stream = true, scroll = false }) + vim.schedule(function() + vim.cmd("redraw") + end) + end, function(err) + signal.is_loading = false - if err ~= nil then - self:update_content(content_prefix .. full_response .. "\n\n🚨 Error: " .. vim.inspect(err)) - return - end - - -- Execute when the stream request is actually completed - self:update_content( - content_prefix - .. full_response - .. "\n\nšŸŽ‰šŸŽ‰šŸŽ‰ **Generation complete!** Please review the code suggestions above.\n\n", - { - callback = function() - api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) - end, - } - ) - - -- Save chat history - table.insert(chat_history or {}, { - timestamp = timestamp, - provider = Config.provider, - model = model, - request = request, - response = full_response, - }) - save_chat_history(self, chat_history) + if err ~= nil then + self:update_content(content_prefix .. full_response .. "\n\n🚨 Error: " .. vim.inspect(err)) + return end - ) + + -- Execute when the stream request is actually completed + self:update_content( + content_prefix + .. full_response + .. "\n\nšŸŽ‰šŸŽ‰šŸŽ‰ **Generation complete!** Please review the code suggestions above.\n\n", + { + callback = function() + api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN }) + end, + } + ) + + -- Save chat history + table.insert(chat_history or {}, { + timestamp = timestamp, + provider = Config.provider, + model = model, + request = request, + response = full_response, + }) + save_chat_history(self, chat_history) + end) if Config.behaviour.auto_apply_diff_after_generation then apply()