diff --git a/README.md b/README.md
index 026d49b..3096545 100644
--- a/README.md
+++ b/README.md
@@ -258,6 +258,99 @@ lua_ls = {
Then you can set `dev = true` in your `lazy` config for development.
+## Custom Providers
+
+To add support for custom providers, one add `AvanteProvider` spec into `opts.vendors`:
+
+```lua
+{
+ provider = "my-custom-provider", -- You can then change this provider here
+ vendors = {
+ ["my-custom-provider"] = {...}
+ },
+ windows = {
+ wrap_line = true,
+ width = 30, -- default % based on available width
+ },
+ --- @class AvanteConflictUserConfig
+ diff = {
+ debug = false,
+ autojump = true,
+ ---@type string | fun(): any
+ list_opener = "copen",
+ },
+}
+
+```
+
+A custom provider should following the following spec:
+
+```lua
+---@type AvanteProvider
+{
+ endpoint = "https://api.openai.com/v1/chat/completions", -- The full endpoint of the provider
+ model = "gpt-4o", -- The model name to use with this provider
+ api_key_name = "OPENAI_API_KEY", -- The name of the environment variable that contains the API key
+ --- This function below will be used to parse in cURL arguments.
+ --- It takes in the provider options as the first argument, followed by code_opts retrieved from given buffer.
+ --- This code_opts include:
+ --- - question: Input from the users
+ --- - code_lang: the language of given code buffer
+ --- - code_content: content of code buffer
+ --- - selected_code_content: (optional) If given code content is selected in visual mode as context.
+ ---@type fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
+ parse_curl_args = function(opts, code_opts) end
+ --- This function will be used to parse incoming SSE stream
+ --- It takes in the data stream as the first argument, followed by opts retrieved from given buffer.
+ --- This opts include:
+ --- - on_chunk: (fun(chunk: string): any) this is invoked on parsing correct delta chunk
+ --- - on_complete: (fun(err: string|nil): any) this is invoked on either complete call or error chunk
+ --- - event_state: SSE event state.
+ ---@type fun(data_stream: string, opts: ResponseParser): nil
+ parse_response_data = function(data_stream, opts) end
+}
+```
+
+
+Full working example of perplexity
+
+```lua
+vendors = {
+ ---@type AvanteProvider
+ perplexity = {
+ endpoint = "https://api.perplexity.ai/chat/completions",
+ model = "llama-3.1-sonar-large-128k-online",
+ api_key_name = "PPLX_API_KEY",
+ --- this function below will be used to parse in cURL arguments.
+ parse_curl_args = function(opts, code_opts)
+ local Llm = require "avante.llm"
+ return {
+ url = opts.endpoint,
+ headers = {
+ ["Accept"] = "application/json",
+ ["Content-Type"] = "application/json",
+ ["Authorization"] = "Bearer " .. os.getenv(opts.api_key_name),
+ },
+ body = {
+ model = opts.model,
+ messages = Llm.make_openai_message(code_opts), -- you can make your own message, but this is very advanced
+ temperature = 0,
+ max_tokens = 8192,
+ stream = true, -- this will be set by default.
+ },
+ }
+ end,
+ -- The below function is used if the vendors has specific SSE spec that is not claude or openai.
+ parse_response_data = function(data_stream, opts)
+ local Llm = require "avante.llm"
+ Llm.parse_openai_response(data_stream, opts)
+ end,
+ },
+},
+```
+
+
+
## License
avante.nvim is licensed under the Apache License. For more details, please refer to the [LICENSE](./LICENSE) file.
diff --git a/lua/avante/config.lua b/lua/avante/config.lua
index 5e83c7a..43a3e9f 100644
--- a/lua/avante/config.lua
+++ b/lua/avante/config.lua
@@ -6,7 +6,7 @@ local M = {}
---@class avante.Config
M.defaults = {
- ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq"
+ ---@alias Provider "openai" | "claude" | "azure" | "deepseek" | "groq" | [string]
provider = "claude", -- "claude" or "openai" or "azure" or "deepseek" or "groq"
openai = {
endpoint = "https://api.openai.com",
@@ -39,6 +39,10 @@ M.defaults = {
temperature = 0,
max_tokens = 4096,
},
+ --- To add support for custom provider, follow the format below
+ --- See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details
+ ---@type table
+ vendors = {},
behaviour = {
auto_apply_diff_after_generation = false, -- Whether to automatically apply diff after LLM response.
},
@@ -100,6 +104,11 @@ function M.setup(opts)
)
end
+---@param opts? avante.Config
+function M.override(opts)
+ M.options = vim.tbl_deep_extend("force", M.options, opts or {})
+end
+
M = setmetatable(M, {
__index = function(_, k)
if M.options[k] then
diff --git a/lua/avante/init.lua b/lua/avante/init.lua
index f1251ed..aeada09 100644
--- a/lua/avante/init.lua
+++ b/lua/avante/init.lua
@@ -201,7 +201,7 @@ function M.setup(opts)
end
require("avante.diff").setup()
- require("avante.ai_bot").setup()
+ require("avante.llm").setup()
-- setup helpers
H.autocmds()
diff --git a/lua/avante/ai_bot.lua b/lua/avante/llm.lua
similarity index 83%
rename from lua/avante/ai_bot.lua
rename to lua/avante/llm.lua
index 400baf7..78adcab 100644
--- a/lua/avante/ai_bot.lua
+++ b/lua/avante/llm.lua
@@ -8,13 +8,13 @@ local Tiktoken = require("avante.tiktoken")
local Dressing = require("avante.ui.dressing")
---@private
----@class AvanteAiBotInternal
+---@class AvanteLLMInternal
local H = {}
----@class avante.AiBot
+---@class avante.LLM
local M = {}
-M.CANCEL_PATTERN = "AvanteAiBotEscape"
+M.CANCEL_PATTERN = "AvanteLLMEscape"
---@class EnvironmentHandler: table<[Provider], string>
local E = {
@@ -31,16 +31,41 @@ local E = {
E = setmetatable(E, {
---@param k Provider
__index = function(_, k)
- return os.getenv(E.env[k]) and true or false
+ local builtins = E.env[k]
+ if builtins then
+ return os.getenv(builtins) and true or false
+ end
+
+ local external = Config.vendors[k]
+ if external then
+ return os.getenv(external.api_key_name) and true or false
+ end
end,
})
+
+---@private
E._once = false
+E.is_default = function(provider)
+ return E.env[provider] and true or false
+end
+
--- return the environment variable name for the given provider
---@param provider? Provider
---@return string the envvar key
E.key = function(provider)
- return E.env[provider or Config.provider]
+ provider = provider or Config.provider
+
+ if E.is_default(provider) then
+ return E.env[provider]
+ end
+
+ local external = Config.vendors[provider]
+ if external then
+ return external.api_key_name
+ else
+ error("Failed to find provider: " .. provider, 2)
+ end
end
---@param provider? Provider
@@ -52,6 +77,7 @@ end
--- This will only run once and spawn a UI for users to input the envvar.
---@param var Provider supported providers
---@param refresh? boolean
+---@private
E.setup = function(var, refresh)
refresh = refresh or false
@@ -160,7 +186,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field code_content string
---@field selected_code_content? string
---
----@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): {role: "user" | "system", content: string | table}[]
+---@class AvanteBaseMessage
+---@field role "user" | "system"
+---@field content string
+---
+---@class AvanteClaudeMessage: AvanteBaseMessage
+---@field role "user"
+---@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[]
+---
+---@alias AvanteOpenAIMessage AvanteBaseMessage
+---
+---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage
+---
+---@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[]
---
---@class AvanteCurlOutput: {url: string, body: table | string, headers: table}
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
@@ -169,12 +207,19 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
---@field event_state string
---@field on_chunk fun(chunk: string): any
---@field on_complete fun(err: string|nil): any
----@field on_error? fun(err_type: string): nil
----@alias AvanteAiResponseParser fun(data_stream: string, opts: ResponseParser): nil
+---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
+---
+---@class AvanteProvider
+---@field endpoint string
+---@field model string
+---@field api_key_name string
+---@field parse_response_data AvanteResponseParser
+---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
------------------------------Anthropic------------------------------
----@type AvanteAiMessageBuilder
+---@param opts AvantePromptOptions
+---@return AvanteClaudeMessage[]
H.make_claude_message = function(opts)
local code_prompt_obj = {
type = "text",
@@ -232,7 +277,7 @@ H.make_claude_message = function(opts)
}
end
----@type AvanteAiResponseParser
+---@type AvanteResponseParser
H.parse_claude_response = function(data_stream, opts)
if opts.event_state == "content_block_delta" then
local json = vim.json.decode(data_stream)
@@ -268,7 +313,8 @@ end
------------------------------OpenAI------------------------------
----@type AvanteAiMessageBuilder
+---@param opts AvantePromptOptions
+---@return AvanteOpenAIMessage[]
H.make_openai_message = function(opts)
local user_prompt = base_user_prompt
.. "\n\nCODE:\n"
@@ -304,7 +350,7 @@ H.make_openai_message = function(opts)
}
end
----@type AvanteAiResponseParser
+---@type AvanteResponseParser
H.parse_openai_response = function(data_stream, opts)
if data_stream:match('"%[DONE%]":') then
opts.on_complete(nil)
@@ -346,7 +392,7 @@ end
---@type AvanteAiMessageBuilder
H.make_azure_message = H.make_openai_message
----@type AvanteAiResponseParser
+---@type AvanteResponseParser
H.parse_azure_response = H.parse_openai_response
---@type AvanteCurlArgsBuilder
@@ -375,7 +421,7 @@ end
---@type AvanteAiMessageBuilder
H.make_deepseek_message = H.make_openai_message
----@type AvanteAiResponseParser
+---@type AvanteResponseParser
H.parse_deepseek_response = H.parse_openai_response
---@type AvanteCurlArgsBuilder
@@ -401,7 +447,7 @@ end
---@type AvanteAiMessageBuilder
H.make_groq_message = H.make_openai_message
----@type AvanteAiResponseParser
+---@type AvanteResponseParser
H.parse_groq_response = H.parse_openai_response
---@type AvanteCurlArgsBuilder
@@ -424,7 +470,7 @@ end
------------------------------Logic------------------------------
-local group = vim.api.nvim_create_augroup("AvanteAiBot", { clear = true })
+local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true })
local active_job = nil
---@param question string
@@ -433,17 +479,35 @@ local active_job = nil
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
-M.invoke_llm_stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
+M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
local provider = Config.provider
local event_state = nil
- ---@type AvanteCurlOutput
- local spec = H["make_" .. provider .. "_curl_args"]({
+ local code_opts = {
question = question,
code_lang = code_lang,
code_content = code_content,
selected_code_content = selected_content_content,
- })
+ }
+ local handler_opts = vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
+
+ ---@type AvanteCurlOutput
+ local spec = nil
+
+ ---@type AvanteProvider
+ local ProviderConfig = nil
+
+ if E.is_default(provider) then
+ spec = H["make_" .. provider .. "_curl_args"](code_opts)
+ else
+ ProviderConfig = Config.vendors[provider]
+ spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
+ end
+ if spec.body.stream == nil then
+ spec = vim.tbl_deep_extend("force", spec, {
+ body = { stream = true },
+ })
+ end
---@param line string
local function parse_and_call(line)
@@ -454,10 +518,11 @@ M.invoke_llm_stream = function(question, code_lang, code_content, selected_conte
end
local data_match = line:match("^data: (.+)$")
if data_match then
- H["parse_" .. provider .. "_response"](
- data_match,
- vim.deepcopy({ on_chunk = on_chunk, on_complete = on_complete, event_state = event_state }, true)
- )
+ if ProviderConfig ~= nil then
+ ProviderConfig.parse_response_data(data_match, handler_opts)
+ else
+ H["parse_" .. provider .. "_response"](data_match, handler_opts)
+ end
end
end
@@ -521,7 +586,7 @@ function M.refresh(provider)
else
vim.notify_once("Switch to provider: " .. provider, vim.log.levels.INFO)
end
- require("avante").setup({ provider = provider })
+ require("avante.config").override({ provider = provider })
end
M.commands = function()
@@ -536,11 +601,25 @@ M.commands = function()
return {}
end
local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or ""
+ -- join two tables
+ local Keys = vim.list_extend(vim.tbl_keys(E.env), vim.tbl_keys(Config.vendors))
return vim.tbl_filter(function(key)
return key:find(prefix) == 1
- end, vim.tbl_keys(E.env))
+ end, Keys)
end,
})
end
-return M
+return setmetatable(M, {
+ __index = function(t, k)
+ local h = H[k]
+ if h then
+ return H[k]
+ end
+ local v = t[k]
+ if v then
+ return t[k]
+ end
+ error("Failed to find key: " .. k)
+ end,
+})
diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua
index 38e530e..c5635ed 100644
--- a/lua/avante/sidebar.lua
+++ b/lua/avante/sidebar.lua
@@ -7,7 +7,7 @@ local N = require("nui-components")
local Config = require("avante.config")
local View = require("avante.view")
local Diff = require("avante.diff")
-local AiBot = require("avante.ai_bot")
+local Llm = require("avante.llm")
local Utils = require("avante.utils")
local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated"
@@ -141,7 +141,7 @@ function Sidebar:intialize()
mode = { "n" },
key = "q",
handler = function()
- api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
+ api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN })
self.renderer:close()
end,
},
@@ -149,7 +149,7 @@ function Sidebar:intialize()
mode = { "n" },
key = "",
handler = function()
- api.nvim_exec_autocmds("User", { pattern = AiBot.CANCEL_PATTERN })
+ api.nvim_exec_autocmds("User", { pattern = Llm.CANCEL_PATTERN })
self.renderer:close()
end,
},
@@ -245,6 +245,9 @@ end
---@param content string concatenated content of the buffer
---@param opts? {focus?: boolean, stream?: boolean, scroll?: boolean, callback?: fun(): nil} whether to focus the result view
function Sidebar:update_content(content, opts)
+ if not self.view.buf then
+ return
+ end
opts = vim.tbl_deep_extend("force", { focus = true, scroll = true, stream = false, callback = nil }, opts or {})
if opts.stream then
vim.schedule(function()
@@ -643,9 +646,16 @@ function Sidebar:render()
signal.is_loading = true
local state = signal:get_value()
local request = state.text
+ ---@type string
+ local model
- local provider_config = Config[Config.provider]
- local model = provider_config and provider_config.model or "default"
+ local builtins_provider_config = Config[Config.provider]
+ if builtins_provider_config ~= nil then
+ model = builtins_provider_config.model
+ else
+ local vendor_provider_config = Config.vendors[Config.provider]
+ model = vendor_provider_config and vendor_provider_config.model or "default"
+ end
local timestamp = get_timestamp()
@@ -670,50 +680,43 @@ function Sidebar:render()
local filetype = api.nvim_get_option_value("filetype", { buf = self.code.buf })
- AiBot.invoke_llm_stream(
- request,
- filetype,
- content_with_line_numbers,
- selected_code_content_with_line_numbers,
- function(chunk)
- signal.is_loading = true
- full_response = full_response .. chunk
- self:update_content(chunk, { stream = true, scroll = false })
- vim.schedule(function()
- vim.cmd("redraw")
- end)
- end,
- function(err)
- signal.is_loading = false
+ Llm.stream(request, filetype, content_with_line_numbers, selected_code_content_with_line_numbers, function(chunk)
+ signal.is_loading = true
+ full_response = full_response .. chunk
+ self:update_content(chunk, { stream = true, scroll = false })
+ vim.schedule(function()
+ vim.cmd("redraw")
+ end)
+ end, function(err)
+ signal.is_loading = false
- if err ~= nil then
- self:update_content(content_prefix .. full_response .. "\n\nšØ Error: " .. vim.inspect(err))
- return
- end
-
- -- Execute when the stream request is actually completed
- self:update_content(
- content_prefix
- .. full_response
- .. "\n\nššš **Generation complete!** Please review the code suggestions above.\n\n",
- {
- callback = function()
- api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
- end,
- }
- )
-
- -- Save chat history
- table.insert(chat_history or {}, {
- timestamp = timestamp,
- provider = Config.provider,
- model = model,
- request = request,
- response = full_response,
- })
- save_chat_history(self, chat_history)
+ if err ~= nil then
+ self:update_content(content_prefix .. full_response .. "\n\nšØ Error: " .. vim.inspect(err))
+ return
end
- )
+
+ -- Execute when the stream request is actually completed
+ self:update_content(
+ content_prefix
+ .. full_response
+ .. "\n\nššš **Generation complete!** Please review the code suggestions above.\n\n",
+ {
+ callback = function()
+ api.nvim_exec_autocmds("User", { pattern = VIEW_BUFFER_UPDATED_PATTERN })
+ end,
+ }
+ )
+
+ -- Save chat history
+ table.insert(chat_history or {}, {
+ timestamp = timestamp,
+ provider = Config.provider,
+ model = model,
+ request = request,
+ response = full_response,
+ })
+ save_chat_history(self, chat_history)
+ end)
if Config.behaviour.auto_apply_diff_after_generation then
apply()