From 8d52229f16be9d132abcb64fa157189667eaba4a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 22 Aug 2024 01:48:40 -0400 Subject: [PATCH] refactor(llm): cleanup providers for future ops (closes #134) (#147) support allow_insecure and proxy ops Signed-off-by: Aaron Pham --- README.md | 2 +- lua/avante/config.lua | 49 +- lua/avante/init.lua | 2 +- lua/avante/llm.lua | 776 +----------------------------- lua/avante/providers/azure.lua | 44 ++ lua/avante/providers/claude.lua | 112 +++++ lua/avante/providers/copilot.lua | 229 +++++++++ lua/avante/providers/deepseek.lua | 41 ++ lua/avante/providers/gemini.lua | 79 +++ lua/avante/providers/groq.lua | 41 ++ lua/avante/providers/init.lua | 282 +++++++++++ lua/avante/providers/openai.lua | 110 +++++ lua/avante/sidebar.lua | 4 +- lua/avante/utils/copilot.lua | 109 ----- lua/avante/utils/init.lua | 1 - 15 files changed, 1007 insertions(+), 874 deletions(-) create mode 100644 lua/avante/providers/azure.lua create mode 100644 lua/avante/providers/claude.lua create mode 100644 lua/avante/providers/copilot.lua create mode 100644 lua/avante/providers/deepseek.lua create mode 100644 lua/avante/providers/gemini.lua create mode 100644 lua/avante/providers/groq.lua create mode 100644 lua/avante/providers/init.lua create mode 100644 lua/avante/providers/openai.lua delete mode 100644 lua/avante/utils/copilot.lua diff --git a/README.md b/README.md index efc3eea..a7fffb1 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_ }, hints = { enabled = true }, windows = { - wrap_line = true, -- similar to vim.o.wrap + wrap = true, -- similar to vim.o.wrap width = 30, -- default % based on available width sidebar_header = { align = "center", -- left, center, right for title diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 08d68ca..5650471 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -42,9 +42,9 @@ M.defaults = { claude = { endpoint = "https://api.anthropic.com", model = "claude-3-5-sonnet-20240620", + ["local"] = false, temperature = 0, max_tokens = 4096, - ["local"] = false, }, ---@type AvanteSupportedProvider deepseek = { @@ -64,10 +64,9 @@ M.defaults = { }, ---@type AvanteGeminiProvider gemini = { - endpoint = "", - type = "gemini", + endpoint = "https://generativelanguage.googleapis.com/v1beta/models", model = "gemini-1.5-pro", - options = {}, + ["local"] = false, }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/README.md#custom-providers for more details @@ -105,12 +104,15 @@ M.defaults = { }, }, windows = { - wrap_line = true, -- similar to vim.o.wrap + wrap = true, -- similar to vim.o.wrap width = 30, -- default % based on available width sidebar_header = { align = "center", -- left, center, right for title rounded = true, }, + prompt = { + prefix = "> ", -- prefix for the prompt + }, }, --- @class AvanteConflictUserConfig diff = { @@ -148,11 +150,25 @@ function M.setup(opts) { mappings = M.options.mappings.diff, highlights = M.options.highlights.diff } ) M.hints = vim.tbl_deep_extend("force", {}, M.options.hints) + + if next(M.options.vendors) ~= nil then + for k, v in pairs(M.options.vendors) do + M.options.vendors[k] = type(v) == "function" and v() or v + end + end end ---@param opts? avante.Config function M.override(opts) - M.options = vim.tbl_deep_extend("force", M.options, opts or {}) + opts = opts or {} + M.options = vim.tbl_deep_extend("force", M.options, opts) + M.diff = vim.tbl_deep_extend( + "force", + {}, + M.options.diff, + { mappings = M.options.mappings.diff, highlights = M.options.highlights.diff } + ) + M.hints = vim.tbl_deep_extend("force", {}, M.options.hints) end M = setmetatable(M, { @@ -167,6 +183,27 @@ function M.get_window_width() return math.ceil(vim.o.columns * (M.windows.width / 100)) end +---@param provider Provider +---@return boolean +M.has_provider = function(provider) + return M.options[provider] ~= nil or M.vendors[provider] ~= nil +end + +---get supported providers +---@param provider Provider +---@return AvanteProvider | fun(): AvanteProvider +M.get_provider = function(provider) + if M.options[provider] ~= nil then + return vim.deepcopy(M.options[provider], true) + elseif M.vendors[provider] ~= nil then + return vim.deepcopy(M.vendors[provider], true) + else + error("Failed to find provider: " .. provider, 2) + end +end + +M.BASE_PROVIDER_KEYS = { "endpoint", "model", "local", "deployment", "api_version", "proxy", "allow_insecure" } + ---@return {width: integer, height: integer} function M.get_sidebar_layout_options() local width = M.get_window_width() diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 7393c19..4e1fff4 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -207,7 +207,7 @@ function M.setup(opts) require("avante.highlights").setup() require("avante.diff").setup() - require("avante.llm").setup() + require("avante.providers").setup() -- setup helpers H.autocmds() diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 73d6efc..a927d1e 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -4,209 +4,21 @@ local curl = require("plenary.curl") local Utils = require("avante.utils") local Config = require("avante.config") -local Tiktoken = require("avante.tiktoken") -local Dressing = require("avante.ui.dressing") +local P = require("avante.providers") ---@class avante.LLM local M = {} M.CANCEL_PATTERN = "AvanteLLMEscape" ----@class CopilotToken ----@field annotations_enabled boolean ----@field chat_enabled boolean ----@field chat_jetbrains_enabled boolean ----@field code_quote_enabled boolean ----@field codesearch boolean ----@field copilotignore_enabled boolean ----@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string} ----@field expires_at integer ----@field individual boolean ----@field nes_enabled boolean ----@field prompt_8k boolean ----@field public_suggestions string ----@field refresh_in integer ----@field sku string ----@field snippy_load_test_enabled boolean ----@field telemetry string ----@field token string ----@field tracking_id string ----@field vsc_electron_fetcher boolean ----@field xcode boolean ----@field xcode_chat boolean ---- ----@private ----@class AvanteCopilot: table ----@field proxy string ----@field allow_insecure boolean ----@field token? CopilotToken ----@field github_token? string ----@field sessionid? string ----@field machineid? string -M.copilot = nil - ----@class EnvironmentHandler: table<[Provider], string> -local E = { - ---@type table - env = { - openai = "OPENAI_API_KEY", - claude = "ANTHROPIC_API_KEY", - azure = "AZURE_OPENAI_API_KEY", - deepseek = "DEEPSEEK_API_KEY", - groq = "GROQ_API_KEY", - gemini = "GEMINI_API_KEY", - copilot = function() - if Utils.has("copilot.lua") or Utils.has("copilot.vim") or Utils.copilot.find_config_path() then - return true - end - Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.") - return false - end, - }, -} - -setmetatable(E, { - ---@param k Provider - __index = function(_, k) - if E.is_local(k) then - return true - end - - local builtins = E.env[k] - if builtins then - if type(builtins) == "function" then - return builtins() - end - return os.getenv(builtins) and true or false - end - - ---@type AvanteProvider | nil - local external = Config.vendors[k] - if external then - return os.getenv(external.api_key_name) and true or false - end - end, -}) - ----@private -E._once = false - ----@param provider Provider -E.is_default = function(provider) - return E.env[provider] and true or false -end - -local AVANTE_INTERNAL_KEY = "__avante_internal" - ---- return the environment variable name for the given provider ----@param provider? Provider ----@return string the envvar key -E.key = function(provider) - provider = provider or Config.provider - - if E.is_default(provider) then - local result = E.env[provider] - return type(result) == "function" and AVANTE_INTERNAL_KEY or result - end - - ---@type AvanteProvider | nil - local external = Config.vendors[provider] - if external then - return external.api_key_name - end - error("Failed to find provider: " .. provider, 2) -end - ----@param provider Provider -E.is_local = function(provider) - if Config.options[provider] then - return Config.options[provider]["local"] - elseif Config.vendors[provider] then - return Config.vendors[provider]["local"] - else - return false - end -end - ----@param provider? Provider -E.value = function(provider) - if E.is_local(provider or Config.provider) then - return "__avante_dummy" - end - return os.getenv(E.key(provider or Config.provider)) -end - ---- intialize the environment variable for current neovim session. ---- This will only run once and spawn a UI for users to input the envvar. ----@param var string supported providers ----@param refresh? boolean ----@private -E.setup = function(var, refresh) - if var == AVANTE_INTERNAL_KEY then - return - end - - refresh = refresh or false - - ---@param value string - ---@return nil - local function on_confirm(value) - if value then - vim.fn.setenv(var, value) - else - if not E[Config.provider] then - Utils.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true, title = "Avante" }) - end - end - end - - if refresh then - vim.defer_fn(function() - Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm }) - end, 200) - elseif not E._once then - E._once = true - api.nvim_create_autocmd({ "BufEnter", "BufWinEnter" }, { - pattern = "*", - once = true, - callback = function() - vim.defer_fn(function() - -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf - local exclude_buftypes = { "dashboard", "alpha", "qf", "nofile" } - local exclude_filetypes = { - "NvimTree", - "Outline", - "help", - "dashboard", - "alpha", - "qf", - "ministarter", - "TelescopePrompt", - "gitcommit", - "gitrebase", - "DressingInput", - } - if - not vim.tbl_contains(exclude_buftypes, vim.bo.buftype) - and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype) - then - Dressing.initialize_input_buffer({ - opts = { prompt = "Enter " .. var .. ": " }, - on_confirm = on_confirm, - }) - end - end, 200) - end, - }) - end -end - ------------------------------Prompt and type------------------------------ +---@alias AvanteSystemPrompt string local system_prompt = [[ You are an excellent programming expert. ]] +---@alias AvanteBasePrompt string local base_user_prompt = [[ Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: @@ -251,490 +63,6 @@ Replace lines: {{start_line}}-{{end_line}} Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. ]] ----@class AvanteHandlerOptions: table<[string], string> ----@field on_chunk AvanteChunkParser ----@field on_complete AvanteCompleteParser ---- ----@class AvantePromptOptions: table<[string], string> ----@field question string ----@field code_lang string ----@field code_content string ----@field selected_code_content? string ---- ----@class AvanteBaseMessage ----@field role "user" | "system" ----@field content string ---- ----@class AvanteClaudeMessage: AvanteBaseMessage ----@field role "user" ----@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[] ---- ----@alias AvanteOpenAIMessage AvanteBaseMessage ---- ----@class AvanteGeminiMessage ----@field role "user" ----@field parts { text: string }[] ---- ----@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage ---- ----@alias AvanteAiMessageBuilder fun(opts: AvantePromptOptions): AvanteChatMessage[] ---- ----@class AvanteCurlOutput: {url: string, body: table | string, headers: table} ----@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput ---- ----@class ResponseParser ----@field on_chunk fun(chunk: string): any ----@field on_complete fun(err: string|nil): any ----@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil ---- ----@class AvanteDefaultBaseProvider ----@field endpoint string ----@field local? boolean ---- ----@class AvanteSupportedProvider: AvanteDefaultBaseProvider ----@field model string ----@field temperature number ----@field max_tokens number ---- ----@class AvanteAzureProvider: AvanteDefaultBaseProvider ----@field deployment string ----@field api_version string ----@field temperature number ----@field max_tokens number ---- ----@class AvanteCopilotProvider: AvanteSupportedProvider ----@field proxy string | nil ----@field allow_insecure boolean ----@field timeout number ---- ----@class AvanteGeminiProvider: AvanteDefaultBaseProvider ----@field model string ----@field type string ----@field options table ---- ----@class AvanteProvider: AvanteDefaultBaseProvider ----@field model? string ----@field api_key_name string ----@field parse_response_data AvanteResponseParser ----@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput ----@field parse_stream_data? fun(line: string, handler_opts: AvanteHandlerOptions): nil ---- ----@alias AvanteChunkParser fun(chunk: string): any ----@alias AvanteCompleteParser fun(err: string|nil): nil - -------------------------------Anthropic------------------------------ - ----@param opts AvantePromptOptions ----@return AvanteClaudeMessage[] -M.make_claude_message = function(opts) - local code_prompt_obj = { - type = "text", - text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), - } - - if Tiktoken.count(code_prompt_obj.text) > 1024 then - code_prompt_obj.cache_control = { type = "ephemeral" } - end - - if opts.selected_code_content then - code_prompt_obj.text = string.format("```%s\n%s```", opts.code_lang, opts.code_content) - end - - local message_content = { - code_prompt_obj, - } - - if opts.selected_code_content then - local selected_code_obj = { - type = "text", - text = string.format("```%s\n%s```", opts.code_lang, opts.selected_code_content), - } - - if Tiktoken.count(selected_code_obj.text) > 1024 then - selected_code_obj.cache_control = { type = "ephemeral" } - end - - table.insert(message_content, selected_code_obj) - end - - table.insert(message_content, { - type = "text", - text = string.format("%s", opts.question), - }) - - local user_prompt = base_user_prompt - - local user_prompt_obj = { - type = "text", - text = user_prompt, - } - - if Tiktoken.count(user_prompt_obj.text) > 1024 then - user_prompt_obj.cache_control = { type = "ephemeral" } - end - - table.insert(message_content, user_prompt_obj) - - return { - { - role = "user", - content = message_content, - }, - } -end - ----@type AvanteResponseParser -M.parse_claude_response = function(data_stream, event_state, opts) - if event_state == "content_block_delta" then - local ok, json = pcall(vim.json.decode, data_stream) - if not ok then - return - end - opts.on_chunk(json.delta.text) - elseif event_state == "message_stop" then - opts.on_complete(nil) - return - elseif event_state == "error" then - opts.on_complete(vim.json.decode(data_stream)) - end -end - ----@type AvanteCurlArgsBuilder -M.make_claude_curl_args = function(code_opts) - return { - url = Utils.trim(Config.claude.endpoint, { suffix = "/" }) .. "/v1/messages", - headers = { - ["Content-Type"] = "application/json", - ["x-api-key"] = E.value("claude"), - ["anthropic-version"] = "2023-06-01", - ["anthropic-beta"] = "prompt-caching-2024-07-31", - }, - body = { - model = Config.claude.model, - system = system_prompt, - stream = true, - messages = M.make_claude_message(code_opts), - temperature = Config.claude.temperature, - max_tokens = Config.claude.max_tokens, - }, - } -end - -------------------------------OpenAI------------------------------ - ----@param opts AvantePromptOptions ----@return AvanteOpenAIMessage[] -M.make_openai_message = function(opts) - local user_prompt = base_user_prompt - .. "\n\nCODE:\n" - .. "```" - .. opts.code_lang - .. "\n" - .. opts.code_content - .. "\n```" - .. "\n\nQUESTION:\n" - .. opts.question - - if opts.selected_code_content ~= nil then - user_prompt = base_user_prompt - .. "\n\nCODE CONTEXT:\n" - .. "```" - .. opts.code_lang - .. "\n" - .. opts.code_content - .. "\n```" - .. "\n\nCODE:\n" - .. "```" - .. opts.code_lang - .. "\n" - .. opts.selected_code_content - .. "\n```" - .. "\n\nQUESTION:\n" - .. opts.question - end - - return { - { role = "system", content = system_prompt }, - { role = "user", content = user_prompt }, - } -end - ----@type AvanteResponseParser -M.parse_openai_response = function(data_stream, _, opts) - if data_stream:match('"%[DONE%]":') then - opts.on_complete(nil) - return - end - if data_stream:match('"delta":') then - local json = vim.json.decode(data_stream) - if json.choices and json.choices[1] then - local choice = json.choices[1] - if choice.finish_reason == "stop" then - opts.on_complete(nil) - elseif choice.delta.content then - opts.on_chunk(choice.delta.content) - end - end - end -end - ----@type AvanteCurlArgsBuilder -M.make_openai_curl_args = function(code_opts) - return { - url = Utils.trim(Config.openai.endpoint, { suffix = "/" }) .. "/v1/chat/completions", - headers = { - ["Content-Type"] = "application/json", - ["Authorization"] = "Bearer " .. E.value("openai"), - }, - body = { - model = Config.openai.model, - messages = M.make_openai_message(code_opts), - temperature = Config.openai.temperature, - max_tokens = Config.openai.max_tokens, - stream = true, - }, - } -end - -------------------------------Copilot------------------------------ ----@type AvanteAiMessageBuilder -M.make_copilot_message = M.make_openai_message - ----@type AvanteResponseParser -M.parse_copilot_response = M.parse_openai_response - ----@type AvanteCurlArgsBuilder -M.make_copilot_curl_args = function(code_opts) - local github_token = Utils.copilot.cached_token() - - if not github_token then - error( - "No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`" - ) - end - - local on_done = function() - return { - url = Utils.trim(Config.copilot.endpoint, { suffix = "/" }) .. "/chat/completions", - proxy = Config.copilot.proxy, - insecure = Config.copilot.allow_insecure, - headers = Utils.copilot.generate_headers(M.copilot.token.token, M.copilot.sessionid, M.copilot.machineid), - body = { - mode = Config.copilot.model, - n = 1, - top_p = 1, - stream = true, - temperature = Config.copilot.temperature, - max_tokens = Config.copilot.max_tokens, - messages = M.make_copilot_message(code_opts), - }, - } - end - - local result = nil - - if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then - local sessionid = Utils.copilot.uuid() .. tostring(math.floor(os.time() * 1000)) - - local url = "https://api.github.com/copilot_internal/v2/token" - local headers = { - ["Authorization"] = "token " .. github_token, - ["Accept"] = "application/json", - } - for key, value in pairs(Utils.copilot.version_headers) do - headers[key] = value - end - - local response = curl.get(url, { - timeout = Config.copilot.timeout, - headers = headers, - proxy = M.copilot.proxy, - insecure = M.copilot.allow_insecure, - on_error = function(err) - error("Failed to get response: " .. vim.inspect(err)) - end, - }) - - M.copilot.sessionid = sessionid - M.copilot.token = vim.json.decode(response.body) - result = on_done() - else - result = on_done() - end - - return result -end - -------------------------------Azure------------------------------ - ----@type AvanteAiMessageBuilder -M.make_azure_message = M.make_openai_message - ----@type AvanteResponseParser -M.parse_azure_response = M.parse_openai_response - ----@type AvanteCurlArgsBuilder -M.make_azure_curl_args = function(code_opts) - return { - url = Config.azure.endpoint - .. "/openai/deployments/" - .. Config.azure.deployment - .. "/chat/completions?api-version=" - .. Config.azure.api_version, - headers = { - ["Content-Type"] = "application/json", - ["api-key"] = E.value("azure"), - }, - body = { - messages = M.make_openai_message(code_opts), - temperature = Config.azure.temperature, - max_tokens = Config.azure.max_tokens, - stream = true, - }, - } -end - -------------------------------Deepseek------------------------------ - ----@type AvanteAiMessageBuilder -M.make_deepseek_message = M.make_openai_message - ----@type AvanteResponseParser -M.parse_deepseek_response = M.parse_openai_response - ----@type AvanteCurlArgsBuilder -M.make_deepseek_curl_args = function(code_opts) - return { - url = Utils.trim(Config.deepseek.endpoint, { suffix = "/" }) .. "/chat/completions", - headers = { - ["Content-Type"] = "application/json", - ["Authorization"] = "Bearer " .. E.value("deepseek"), - }, - body = { - model = Config.deepseek.model, - messages = M.make_openai_message(code_opts), - temperature = Config.deepseek.temperature, - max_tokens = Config.deepseek.max_tokens, - stream = true, - }, - } -end - -------------------------------Grok------------------------------ - ----@type AvanteAiMessageBuilder -M.make_groq_message = M.make_openai_message - ----@type AvanteResponseParser -M.parse_groq_response = M.parse_openai_response - ----@type AvanteCurlArgsBuilder -M.make_groq_curl_args = function(code_opts) - return { - url = Utils.trim(Config.groq.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions", - headers = { - ["Content-Type"] = "application/json", - ["Authorization"] = "Bearer " .. E.value("groq"), - }, - body = { - model = Config.groq.model, - messages = M.make_openai_message(code_opts), - temperature = Config.groq.temperature, - max_tokens = Config.groq.max_tokens, - stream = true, - }, - } -end - -------------------------------Gemini------------------------------ - ----@param opts AvantePromptOptions ----@return AvanteGeminiMessage[] -M.make_gemini_message = function(opts) - local code_prompt_obj = { - text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), - } - - if opts.selected_code_content then - code_prompt_obj.text = string.format("```%s\n%s```", opts.code_lang, opts.code_content) - end - - -- parts ready - local message_content = { - code_prompt_obj, - } - - if opts.selected_code_content then - local selected_code_obj = { - text = string.format("```%s\n%s```", opts.code_lang, opts.selected_code_content), - } - - table.insert(message_content, selected_code_obj) - end - - -- insert a part into parts - table.insert(message_content, { - text = string.format("%s", opts.question), - }) - - -- local user_prompt_obj = { - -- text = base_user_prompt, - -- } - - -- insert another part into parts - -- table.insert(message_content, user_prompt_obj) - - return { - { - role = "user", - parts = message_content, - }, - } -end - ----@type AvanteResponseParser -M.parse_gemini_response = function(data_stream, event_state, opts) - local json = vim.json.decode(data_stream) - opts.on_chunk(json.candidates[1].content.parts[1].text) -end - ----@type AvanteCurlArgsBuilder -M.make_gemini_curl_args = function(code_opts) - local endpoint = "" - if Config.gemini.endpoint == "" then - endpoint = "https://generativelanguage.googleapis.com/v1beta/models/" - .. Config.gemini.model - .. ":streamGenerateContent?alt=sse&key=" - .. E.value("gemini") - end - -- Prepare the body with contents and options (only if options are not empty) - local body = { - systemInstruction = { - role = "user", - parts = { - { - text = system_prompt .. base_user_prompt, - }, - }, - }, - contents = M.make_gemini_message(code_opts), - } - if next(Config.gemini.options) ~= nil then -- Check if options table is not empty - for k, v in pairs(Config.gemini.options) do - body[k] = v - end - end - return { - url = endpoint, - headers = { - ["Content-Type"] = "application/json", - }, - body = body, - } -end - -------------------------------Logic------------------------------ - local group = vim.api.nvim_create_augroup("AvanteLLM", { clear = true }) local active_job = nil @@ -747,30 +75,29 @@ local active_job = nil M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete) local provider = Config.provider + ---@type AvantePromptOptions local code_opts = { + base_prompt = base_user_prompt, + system_prompt = system_prompt, question = question, code_lang = code_lang, code_content = code_content, selected_code_content = selected_content_content, } + + ---@type string local current_event_state = nil + + ---@type AvanteProviderFunctor + local Provider = P[provider] + + ---@type AvanteHandlerOptions local handler_opts = { on_chunk = on_chunk, on_complete = on_complete } - ---@type AvanteCurlOutput - local spec = nil - - ---@type AvanteProvider - local ProviderConfig = nil - - if E.is_default(provider) then - spec = M["make_" .. provider .. "_curl_args"](code_opts) - else - ProviderConfig = Config.vendors[provider] - spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts) - end + local spec = Provider.parse_curl_args(Config.get_provider(provider), code_opts) ---@param line string - local function parse_and_call(line) + local function parse_stream_data(line) local event = line:match("^event: (.+)$") if event then current_event_state = event @@ -778,11 +105,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content, end local data_match = line:match("^data: (.+)$") if data_match then - if ProviderConfig ~= nil then - ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts) - else - M["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts) - end + Provider.parse_response(data_match, current_event_state, handler_opts) end end @@ -793,6 +116,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content, active_job = curl.post(spec.url, { headers = spec.headers, + proxy = spec.proxy, + insecure = spec.insecure, body = vim.json.encode(spec.body), stream = function(err, data, _) if err then @@ -803,16 +128,16 @@ M.stream = function(question, code_lang, code_content, selected_content_content, return end vim.schedule(function() - if ProviderConfig ~= nil and ProviderConfig.parse_stream_data ~= nil then - if ProviderConfig.parse_response_data ~= nil then + if Config.options[provider] == nil and Provider.parse_stream_data ~= nil then + if Provider.parse_response ~= nil then Utils.warn( "parse_stream_data and parse_response_data are mutually exclusive, and thus parse_response_data will be ignored. Make sure that you handle the incoming data correctly.", { once = true } ) end - ProviderConfig.parse_stream_data(data, handler_opts) + Provider.parse_stream_data(data, handler_opts) else - parse_and_call(data) + parse_stream_data(data) end end) end, @@ -839,61 +164,4 @@ M.stream = function(question, code_lang, code_content, selected_content_content, return active_job end ----@public -function M.setup() - if Config.provider == "copilot" and not M.copilot then - M.copilot = { - proxy = Config.copilot.proxy, - allow_insecure = Config.copilot.allow_insecure, - github_token = Utils.copilot.cached_token(), - sessionid = nil, - token = nil, - machineid = Utils.copilot.machine_id(), - } - end - - local has = E[Config.provider] - if not has then - E.setup(E.key()) - end - - M.commands() -end - ----@param provider Provider -function M.refresh(provider) - local has = E[provider] - if not has then - E.setup(E.key(provider), true) - else - Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) - end - require("avante.config").override({ provider = provider }) -end - ----@private -M.commands = function() - api.nvim_create_user_command("AvanteSwitchProvider", function(args) - local cmd = vim.trim(args.args or "") - M.refresh(cmd) - end, { - nargs = 1, - desc = "avante: switch provider", - complete = function(_, line) - if line:match("^%s*AvanteSwitchProvider %w") then - return {} - end - local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or "" - -- join two tables - local Keys = vim.list_extend(vim.tbl_keys(E.env), vim.tbl_keys(Config.vendors)) - return vim.tbl_filter(function(key) - return key:find(prefix) == 1 - end, Keys) - end, - }) -end - -M.SYSTEM_PROMPT = system_prompt -M.BASE_PROMPT = base_user_prompt - return M diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua new file mode 100644 index 0000000..1088655 --- /dev/null +++ b/lua/avante/providers/azure.lua @@ -0,0 +1,44 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") +local O = require("avante.providers").openai + +---@class AvanteProviderFunctor +local M = {} + +M.API_KEY = "AZURE_OPENAI_API_KEY" + +M.has = function() + return os.getenv(M.API_KEY) and true or false +end + +M.parse_message = O.parse_message +M.parse_response = O.parse_response + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + local headers = { + ["Content-Type"] = "application/json", + } + if not P.env.is_local("azure") then + headers["api-key"] = os.getenv(M.API_KEY) + end + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) + .. "/openai/deployments/" + .. base.deployment + .. "/chat/completions?api-version=" + .. base.api_version, + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + messages = M.parse_message(code_opts), + stream = true, + }, body_opts), + } +end + +return M diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua new file mode 100644 index 0000000..8cd770e --- /dev/null +++ b/lua/avante/providers/claude.lua @@ -0,0 +1,112 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local Tiktoken = require("avante.tiktoken") +local P = require("avante.providers") + +---@class AvanteProviderFunctor +local M = {} + +M.API_KEY = "ANTHROPIC_API_KEY" + +M.has = function() + return os.getenv(M.API_KEY) and true or false +end + +M.parse_message = function(opts) + local code_prompt_obj = { + type = "text", + text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), + } + + if Tiktoken.count(code_prompt_obj.text) > 1024 then + code_prompt_obj.cache_control = { type = "ephemeral" } + end + + if opts.selected_code_content then + code_prompt_obj.text = string.format("```%s\n%s```", opts.code_lang, opts.code_content) + end + + local message_content = { + code_prompt_obj, + } + + if opts.selected_code_content then + local selected_code_obj = { + type = "text", + text = string.format("```%s\n%s```", opts.code_lang, opts.selected_code_content), + } + + if Tiktoken.count(selected_code_obj.text) > 1024 then + selected_code_obj.cache_control = { type = "ephemeral" } + end + + table.insert(message_content, selected_code_obj) + end + + table.insert(message_content, { + type = "text", + text = string.format("%s", opts.question), + }) + + local user_prompt = opts.base_prompt + + local user_prompt_obj = { + type = "text", + text = user_prompt, + } + + if Tiktoken.count(user_prompt_obj.text) > 1024 then + user_prompt_obj.cache_control = { type = "ephemeral" } + end + + table.insert(message_content, user_prompt_obj) + + return { + { + role = "user", + content = message_content, + }, + } +end + +M.parse_response = function(data_stream, event_state, opts) + if event_state == "content_block_delta" then + local ok, json = pcall(vim.json.decode, data_stream) + if not ok then + return + end + opts.on_chunk(json.delta.text) + elseif event_state == "message_stop" then + opts.on_complete(nil) + return + elseif event_state == "error" then + opts.on_complete(vim.json.decode(data_stream)) + end +end + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + local headers = { + ["Content-Type"] = "application/json", + ["anthropic-version"] = "2023-06-01", + ["anthropic-beta"] = "prompt-caching-2024-07-31", + } + if not P.env.is_local("claude") then + headers["x-api-key"] = os.getenv(M.API_KEY) + end + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + model = base.model, + messages = M.parse_message(code_opts), + stream = true, + }, body_opts), + } +end + +return M diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua new file mode 100644 index 0000000..af6374a --- /dev/null +++ b/lua/avante/providers/copilot.lua @@ -0,0 +1,229 @@ +local curl = require("plenary.curl") + +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") +local O = require("avante.providers").openai + +---@class AvanteProviderFunctor +local M = {} + +---@class CopilotToken +---@field annotations_enabled boolean +---@field chat_enabled boolean +---@field chat_jetbrains_enabled boolean +---@field code_quote_enabled boolean +---@field codesearch boolean +---@field copilotignore_enabled boolean +---@field endpoints {api: string, ["origin-tracker"]: string, proxy: string, telemetry: string} +---@field expires_at integer +---@field individual boolean +---@field nes_enabled boolean +---@field prompt_8k boolean +---@field public_suggestions string +---@field refresh_in integer +---@field sku string +---@field snippy_load_test_enabled boolean +---@field telemetry string +---@field token string +---@field tracking_id string +---@field vsc_electron_fetcher boolean +---@field xcode boolean +---@field xcode_chat boolean +--- +---@private +---@class AvanteCopilot: table +---@field token? CopilotToken +---@field github_token? string +---@field sessionid? string +---@field machineid? string +M.copilot = nil + +local H = {} + +local version_headers = { + ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, + ["editor-plugin-version"] = "avante.nvim/0.0.0", + ["user-agent"] = "avante.nvim/0.0.0", +} + +---@return string +H.uuid = function() + local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" + return ( + string.gsub(template, "[xy]", function(c) + local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb) + return string.format("%x", v) + end) + ) +end + +---@return string +H.machine_id = function() + local length = 65 + local hex_chars = "0123456789abcdef" + local hex = "" + for _ = 1, length do + hex = hex .. hex_chars:sub(math.random(1, #hex_chars), math.random(1, #hex_chars)) + end + return hex +end + +---@return string | nil +H.find_config_path = function() + local config = vim.fn.expand("$XDG_CONFIG_HOME") + if config and vim.fn.isdirectory(config) > 0 then + return config + elseif vim.fn.has("win32") > 0 then + config = vim.fn.expand("~/AppData/Local") + if vim.fn.isdirectory(config) > 0 then + return config + end + else + config = vim.fn.expand("~/.config") + if vim.fn.isdirectory(config) > 0 then + return config + end + end +end + +H.cached_token = function() + -- loading token from the environment only in GitHub Codespaces + local token = os.getenv("GITHUB_TOKEN") + local codespaces = os.getenv("CODESPACES") + if token and codespaces then + return token + end + + -- loading token from the file + local config_path = H.find_config_path() + if not config_path then + return nil + end + + -- token can be sometimes in apps.json sometimes in hosts.json + local file_paths = { + config_path .. "/github-copilot/hosts.json", + config_path .. "/github-copilot/apps.json", + } + + for _, file_path in ipairs(file_paths) do + if vim.fn.filereadable(file_path) == 1 then + local userdata = vim.fn.json_decode(vim.fn.readfile(file_path)) + for key, value in pairs(userdata) do + if string.find(key, "github.com") then + return value.oauth_token + end + end + end + end + + return nil +end + +---@param token string +---@param sessionid string +---@param machineid string +---@return table +H.generate_headers = function(token, sessionid, machineid) + local headers = { + ["authorization"] = "Bearer " .. token, + ["x-request-id"] = H.uuid(), + ["vscode-sessionid"] = sessionid, + ["vscode-machineid"] = machineid, + ["copilot-integration-id"] = "vscode-chat", + ["openai-organization"] = "github-copilot", + ["openai-intent"] = "conversation-panel", + ["content-type"] = "application/json", + } + for key, value in pairs(version_headers) do + headers[key] = value + end + return headers +end + +M.API_KEY = P.AVANTE_INTERNAL_KEY + +M.has = function() + if Utils.has("copilot.lua") or Utils.has("copilot.vim") or H.find_config_path() then + return true + end + Utils.warn("copilot is not setup correctly. Please use copilot.lua or copilot.vim for authentication.") + return false +end + +M.parse_message = O.parse_message +M.parse_response = O.parse_response + +M.parse_curl_args = function(provider, code_opts) + local github_token = H.cached_token() + + if not github_token then + error( + "No GitHub token found, please use `:Copilot auth` to setup with `copilot.lua` or `:Copilot setup` with `copilot.vim`" + ) + end + local base, body_opts = P.parse_config(provider) + + local on_done = function() + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = H.generate_headers(M.copilot.token.token, M.copilot.sessionid, M.copilot.machineid), + body = vim.tbl_deep_extend("force", { + mode = base.model, + n = 1, + top_p = 1, + stream = true, + messages = M.parse_message(code_opts), + }, body_opts), + } + end + + local result = nil + + if not M.copilot.token or (M.copilot.token.expires_at and M.copilot.token.expires_at <= math.floor(os.time())) then + local sessionid = H.uuid() .. tostring(math.floor(os.time() * 1000)) + + local url = "https://api.github.com/copilot_internal/v2/token" + local headers = { + ["Authorization"] = "token " .. github_token, + ["Accept"] = "application/json", + } + for key, value in pairs(version_headers) do + headers[key] = value + end + + local response = curl.get(url, { + timeout = Config.copilot.timeout, + headers = headers, + proxy = base.proxy, + insecure = base.allow_insecure, + on_error = function(err) + error("Failed to get response: " .. vim.inspect(err)) + end, + }) + + M.copilot.sessionid = sessionid + M.copilot.token = vim.json.decode(response.body) + result = on_done() + else + result = on_done() + end + + return result +end + +M.setup = function() + if not M.copilot then + M.copilot = { + sessionid = nil, + token = nil, + github_token = H.cached_token(), + machineid = H.machine_id(), + } + end +end + +return M diff --git a/lua/avante/providers/deepseek.lua b/lua/avante/providers/deepseek.lua new file mode 100644 index 0000000..effef50 --- /dev/null +++ b/lua/avante/providers/deepseek.lua @@ -0,0 +1,41 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") +local O = require("avante.providers").openai + +---@class AvanteProviderFunctor +local M = {} + +M.API_KEY = "DEEPSEEK_API_KEY" + +M.has = function() + return os.getenv(M.API_KEY) and true or false +end + +M.parse_message = O.parse_message +M.parse_response = O.parse_response + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + local headers = { + ["Content-Type"] = "application/json", + } + if not P.env.is_local("deepseek") then + headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY) + end + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/chat/completions", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + model = base.model, + messages = M.parse_message(code_opts), + stream = true, + }, body_opts), + } +end + +return M diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua new file mode 100644 index 0000000..5aee0e6 --- /dev/null +++ b/lua/avante/providers/gemini.lua @@ -0,0 +1,79 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") + +---@class AvanteProviderFunctor +local M = {} + +M.API_KEY = "GROQ_API_KEY" + +M.has = function() + return os.getenv(M.API_KEY) and true or false +end + +M.parse_message = function(opts) + local code_prompt_obj = { + text = string.format("```%s\n%s```", opts.code_lang, opts.code_content), + } + + if opts.selected_code_content then + code_prompt_obj.text = string.format("```%s\n%s```", opts.code_lang, opts.code_content) + end + + -- parts ready + local message_content = { + code_prompt_obj, + } + + if opts.selected_code_content then + local selected_code_obj = { + text = string.format("```%s\n%s```", opts.code_lang, opts.selected_code_content), + } + + table.insert(message_content, selected_code_obj) + end + + -- insert a part into parts + table.insert(message_content, { + text = string.format("%s", opts.question), + }) + + return { + systemInstruction = { + role = "user", + parts = { + { + text = opts.system_prompt .. "\n" .. opts.base_prompt, + }, + }, + }, + contents = { + { + role = "user", + parts = message_content, + }, + }, + } +end +M.parse_response = function(data_stream, _, opts) + local json = vim.json.decode(data_stream) + opts.on_chunk(json.candidates[1].content.parts[1].text) +end + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) + .. "/" + .. base.model + .. ":streamGenerateContent?alt=sse&key=" + .. os.getenv(M.API_KEY), + proxy = base.proxy, + insecure = base.allow_insecure, + headers = { ["Content-Type"] = "application/json" }, + body = vim.tbl_deep_extend("force", {}, M.parse_message(code_opts), body_opts), + } +end + +return M diff --git a/lua/avante/providers/groq.lua b/lua/avante/providers/groq.lua new file mode 100644 index 0000000..779d932 --- /dev/null +++ b/lua/avante/providers/groq.lua @@ -0,0 +1,41 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") +local O = require("avante.providers").openai + +---@class AvanteProviderFunctor +local M = {} + +M.API_KEY = "GROQ_API_KEY" + +M.has = function() + return os.getenv(M.API_KEY) and true or false +end + +M.parse_message = O.parse_message +M.parse_response = O.parse_response + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + local headers = { + ["Content-Type"] = "application/json", + } + if not P.env.is_local("groq") then + headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY) + end + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/openai/v1/chat/completions", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + model = base.model, + messages = M.parse_message(code_opts), + stream = true, + }, body_opts), + } +end + +return M diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua new file mode 100644 index 0000000..4e3a28e --- /dev/null +++ b/lua/avante/providers/init.lua @@ -0,0 +1,282 @@ +local api = vim.api + +local Config = require("avante.config") +local Utils = require("avante.utils") +local Dressing = require("avante.ui.dressing") + +---@class AvanteHandlerOptions: table<[string], string> +---@field on_chunk AvanteChunkParser +---@field on_complete AvanteCompleteParser +--- +---@class AvantePromptOptions: table<[string], string> +---@field base_prompt AvanteBasePrompt +---@field system_prompt AvanteSystemPrompt +---@field question string +---@field code_lang string +---@field code_content string +---@field selected_code_content? string +--- +---@class AvanteBaseMessage +---@field role "user" | "system" +---@field content string +--- +---@class AvanteClaudeMessage: AvanteBaseMessage +---@field role "user" +---@field content {type: "text", text: string, cache_control?: {type: "ephemeral"}}[] +--- +---@class AvanteGeminiMessage +---@field role "user" +---@field parts { text: string }[] +--- +---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage +--- +---@alias AvanteMessageParser fun(opts: AvantePromptOptions): AvanteChatMessage[] +--- +---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table | string, headers: table} +---@alias AvanteCurlArgsParser fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput +--- +---@class ResponseParser +---@field on_chunk fun(chunk: string): any +---@field on_complete fun(err: string|nil): any +---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil +--- +---@class AvanteDefaultBaseProvider: table +---@field endpoint? string +---@field model? string +---@field local? boolean +---@field proxy? string +---@field allow_insecure? boolean +--- +---@class AvanteSupportedProvider: AvanteDefaultBaseProvider +---@field temperature? number +---@field max_tokens? number +--- +---@class AvanteAzureProvider: AvanteDefaultBaseProvider +---@field deployment string +---@field api_version string +---@field temperature number +---@field max_tokens number +--- +---@class AvanteCopilotProvider: AvanteSupportedProvider +---@field timeout number +--- +---@class AvanteGeminiProvider: AvanteDefaultBaseProvider +---@field model string +--- +---@class AvanteProvider: AvanteDefaultBaseProvider +---@field api_key_name string +---@field parse_response_data AvanteResponseParser +---@field parse_curl_args AvanteCurlArgsParser +---@field parse_stream_data? AvanteStreamParser +--- +---@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil +---@alias AvanteChunkParser fun(chunk: string): any +---@alias AvanteCompleteParser fun(err: string|nil): nil +---@alias AvanteLLMConfigHandler fun(opts: AvanteSupportedProvider): AvanteDefaultBaseProvider, table +--- +---@class AvanteProviderFunctor +---@field parse_message AvanteMessageParser +---@field parse_response AvanteResponseParser +---@field parse_curl_args AvanteCurlArgsParser +---@field setup? fun(): nil +---@field has fun(): boolean +---@field API_KEY string +---@field parse_stream_data? AvanteStreamParser +--- +---@class avante.Providers +---@field openai AvanteProviderFunctor +---@field copilot AvanteProviderFunctor +---@field claude AvanteProviderFunctor +---@field azure AvanteProviderFunctor +---@field deepseek AvanteProviderFunctor +---@field gemini AvanteProviderFunctor +---@field groq AvanteProviderFunctor +local M = {} + +setmetatable(M, { + ---@param t avante.Providers + ---@param k Provider + __index = function(t, k) + if Config.vendors[k] ~= nil then + ---@type AvanteProvider + local v = Config.vendors[k] + + -- Patch from vendors similar to supported providers. + t[k] = setmetatable({}, { __index = v }) + t[k].API_KEY = v.api_key_name + -- Hack for aliasing and makes it sane for us. + t[k].parse_response = v.parse_response_data + t[k].has = function() + return os.getenv(v.api_key_name) and true or false + end + + return t[k] + end + + ---@type AvanteProviderFunctor + t[k] = require("avante.providers." .. k) + return t[k] + end, +}) + +---@class EnvironmentHandler +local E = {} + +---@private +E._once = false + +--- intialize the environment variable for current neovim session. +--- This will only run once and spawn a UI for users to input the envvar. +---@param opts {refresh: boolean, provider: AvanteProviderFunctor} +---@private +E.setup = function(opts) + local var = opts.provider.API_KEY + + if var == M.AVANTE_INTERNAL_KEY then + return + end + + local refresh = opts.refresh or false + + ---@param value string + ---@return nil + local function on_confirm(value) + if value then + vim.fn.setenv(var, value) + else + if not opts.provider.has() then + Utils.warn("Failed to set " .. var .. ". Avante won't work as expected", { once = true, title = "Avante" }) + end + end + end + + if refresh then + vim.defer_fn(function() + Dressing.initialize_input_buffer({ opts = { prompt = "Enter " .. var .. ": " }, on_confirm = on_confirm }) + end, 200) + elseif not E._once then + E._once = true + api.nvim_create_autocmd({ "BufEnter", "BufWinEnter", "WinEnter" }, { + pattern = "*", + once = true, + callback = function() + vim.defer_fn(function() + -- only mount if given buffer is not of buftype ministarter, dashboard, alpha, qf + local exclude_buftypes = { "qf", "nofile" } + local exclude_filetypes = { + "NvimTree", + "Outline", + "help", + "dashboard", + "alpha", + "qf", + "ministarter", + "TelescopePrompt", + "gitcommit", + "gitrebase", + "DressingInput", + } + if + not vim.tbl_contains(exclude_buftypes, vim.bo.buftype) + and not vim.tbl_contains(exclude_filetypes, vim.bo.filetype) + and not opts.provider.has() + then + Dressing.initialize_input_buffer({ + opts = { prompt = "Enter " .. var .. ": " }, + on_confirm = on_confirm, + }) + end + end, 200) + end, + }) + end +end + +---@param provider Provider +E.is_local = function(provider) + local cur = M.get(provider) + return cur["local"] ~= nil and cur["local"] or false +end + +M.env = E + +M.AVANTE_INTERNAL_KEY = "__avante_env_internal" + +M.setup = function() + ---@type AvanteProviderFunctor + local provider = M[Config.provider] + E.setup({ provider = provider }) + + if provider.setup ~= nil then + provider.setup() + end + + M.commands() +end + +---@private +---@param provider Provider +function M.refresh(provider) + ---@type AvanteProviderFunctor + local p = M[Config.provider] + if not p.has() then + E.setup({ provider = p, refresh = true }) + else + Utils.info("Switch to provider: " .. provider, { once = true, title = "Avante" }) + end + require("avante.config").override({ provider = provider }) +end + +local default_providers = { "openai", "claude", "azure", "deepseek", "groq", "gemini", "copilot" } + +---@private +M.commands = function() + api.nvim_create_user_command("AvanteSwitchProvider", function(args) + local cmd = vim.trim(args.args or "") + M.refresh(cmd) + end, { + nargs = 1, + desc = "avante: switch provider", + complete = function(_, line) + if line:match("^%s*AvanteSwitchProvider %w") then + return {} + end + local prefix = line:match("^%s*AvanteSwitchProvider (%w*)") or "" + -- join two tables + local Keys = vim.list_extend(default_providers, vim.tbl_keys(Config.vendors or {})) + return vim.tbl_filter(function(key) + return key:find(prefix) == 1 + end, Keys) + end, + }) +end + +---@param opts AvanteProvider | AvanteSupportedProvider +---@return AvanteDefaultBaseProvider, table +M.parse_config = function(opts) + ---@type AvanteDefaultBaseProvider + local s1 = {} + ---@type table + local s2 = {} + + for key, value in pairs(opts) do + if vim.tbl_contains(Config.BASE_PROVIDER_KEYS, key) then + s1[key] = value + else + s2[key] = value + end + end + + return s1, vim.tbl_filter(function(it) + return type(it) ~= "function" + end, s2) +end + +---@private +---@param provider Provider +M.get = function(provider) + local cur = Config.get_provider(provider or Config.provider) + return type(cur) == "function" and cur() or cur +end + +return M diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua new file mode 100644 index 0000000..d2a3dc7 --- /dev/null +++ b/lua/avante/providers/openai.lua @@ -0,0 +1,110 @@ +local Utils = require("avante.utils") +local Config = require("avante.config") +local P = require("avante.providers") + +---@class OpenAIChatResponse +---@field id string +---@field object "chat.completion" | "chat.completion.chunk" +---@field created integer +---@field model string +---@field system_fingerprint string +---@field choices? OpenAIResponseChoice[] +---@field usage {prompt_tokens: integer, completion_tokens: integer, total_tokens: integer} +--- +---@class OpenAIResponseChoice +---@field index integer +---@field delta OpenAIMessage +---@field logprobs? integer +---@field finish_reason? "stop" | "length" +--- +---@class OpenAIMessage +---@field role? "user" | "system" | "assistant" +---@field content string +--- +---@class AvanteProviderFunctor +local M = {} + +M.API_KEY = "OPENAI_API_KEY" + +M.has = function() + return os.getenv(M.API_KEY) and true or false +end + +M.parse_message = function(opts) + local user_prompt = opts.base_prompt + .. "\n\nCODE:\n" + .. "```" + .. opts.code_lang + .. "\n" + .. opts.code_content + .. "\n```" + .. "\n\nQUESTION:\n" + .. opts.question + + if opts.selected_code_content ~= nil then + user_prompt = opts.base_prompt + .. "\n\nCODE CONTEXT:\n" + .. "```" + .. opts.code_lang + .. "\n" + .. opts.code_content + .. "\n```" + .. "\n\nCODE:\n" + .. "```" + .. opts.code_lang + .. "\n" + .. opts.selected_code_content + .. "\n```" + .. "\n\nQUESTION:\n" + .. opts.question + end + + return { + { role = "system", content = opts.system_prompt }, + { role = "user", content = user_prompt }, + } +end + +M.parse_response = function(data_stream, _, opts) + if data_stream:match('"%[DONE%]":') then + opts.on_complete(nil) + return + end + if data_stream:match('"delta":') then + ---@type OpenAIChatResponse + local json = vim.json.decode(data_stream) + if json.choices and json.choices[1] then + local choice = json.choices[1] + if choice.finish_reason == "stop" then + opts.on_complete(nil) + elseif choice.delta.content then + opts.on_chunk(choice.delta.content) + end + end + end +end + +M.parse_curl_args = function(provider, code_opts) + local base, body_opts = P.parse_config(provider) + + local headers = { + ["Content-Type"] = "application/json", + } + if not P.env.is_local("openai") then + headers["Authorization"] = "Bearer " .. os.getenv(M.API_KEY) + end + + return { + url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/chat/completions", + proxy = base.proxy, + insecure = base.allow_insecure, + headers = headers, + body = vim.tbl_deep_extend("force", { + model = base.model, + messages = M.parse_message(code_opts), + stream = true, + }, body_opts), + } +end + +return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index c8802ad..2005063 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -472,7 +472,7 @@ end function Sidebar:on_mount() self:refresh_winids() - api.nvim_set_option_value("wrap", Config.windows.wrap_line, { win = self.result.winid }) + api.nvim_set_option_value("wrap", Config.windows.wrap, { win = self.result.winid }) local current_apply_extmark_id = nil @@ -1109,7 +1109,7 @@ function Sidebar:create_input() width = win_width - 2, -- Subtract the width of the input box borders }, }, { - prompt = "> ", + prompt = Config.windows.prompt.prefix, default_value = "", on_submit = function(user_input) if user_input == "" then diff --git a/lua/avante/utils/copilot.lua b/lua/avante/utils/copilot.lua deleted file mode 100644 index 3dd1c79..0000000 --- a/lua/avante/utils/copilot.lua +++ /dev/null @@ -1,109 +0,0 @@ ----This file COPY and MODIFIED based on: https://github.com/CopilotC-Nvim/CopilotChat.nvim/blob/canary/lua/CopilotChat/copilot.lua#L560 - ----@class avante.utils.copilot -local M = {} - -local version_headers = { - ["editor-version"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, - ["editor-plugin-version"] = "avante.nvim/0.0.0", - ["user-agent"] = "avante.nvim/0.0.0", -} - ----@return string -M.uuid = function() - local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" - return ( - string.gsub(template, "[xy]", function(c) - local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb) - return string.format("%x", v) - end) - ) -end - ----@return string -M.machine_id = function() - local length = 65 - local hex_chars = "0123456789abcdef" - local hex = "" - for _ = 1, length do - hex = hex .. hex_chars:sub(math.random(1, #hex_chars), math.random(1, #hex_chars)) - end - return hex -end - ----@return string | nil -M.find_config_path = function() - local config = vim.fn.expand("$XDG_CONFIG_HOME") - if config and vim.fn.isdirectory(config) > 0 then - return config - elseif vim.fn.has("win32") > 0 then - config = vim.fn.expand("~/AppData/Local") - if vim.fn.isdirectory(config) > 0 then - return config - end - else - config = vim.fn.expand("~/.config") - if vim.fn.isdirectory(config) > 0 then - return config - end - end -end - -M.cached_token = function() - -- loading token from the environment only in GitHub Codespaces - local token = os.getenv("GITHUB_TOKEN") - local codespaces = os.getenv("CODESPACES") - if token and codespaces then - return token - end - - -- loading token from the file - local config_path = M.find_config_path() - if not config_path then - return nil - end - - -- token can be sometimes in apps.json sometimes in hosts.json - local file_paths = { - config_path .. "/github-copilot/hosts.json", - config_path .. "/github-copilot/apps.json", - } - - for _, file_path in ipairs(file_paths) do - if vim.fn.filereadable(file_path) == 1 then - local userdata = vim.fn.json_decode(vim.fn.readfile(file_path)) - for key, value in pairs(userdata) do - if string.find(key, "github.com") then - return value.oauth_token - end - end - end - end - - return nil -end - ----@param token string ----@param sessionid string ----@param machineid string ----@return table -M.generate_headers = function(token, sessionid, machineid) - local headers = { - ["authorization"] = "Bearer " .. token, - ["x-request-id"] = M.uuid(), - ["vscode-sessionid"] = sessionid, - ["vscode-machineid"] = machineid, - ["copilot-integration-id"] = "vscode-chat", - ["openai-organization"] = "github-copilot", - ["openai-intent"] = "conversation-panel", - ["content-type"] = "application/json", - } - for key, value in pairs(version_headers) do - headers[key] = value - end - return headers -end - -M.version_headers = version_headers - -return M diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 62d34d4..8dc790e 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -2,7 +2,6 @@ local api = vim.api ---@class avante.Utils: LazyUtilCore ---@field colors avante.util.colors ----@field copilot avante.utils.copilot local M = {} setmetatable(M, {