diff --git a/lua/avante/ai_bot.lua b/lua/avante/ai_bot.lua new file mode 100644 index 0000000..921ab12 --- /dev/null +++ b/lua/avante/ai_bot.lua @@ -0,0 +1,262 @@ +local M = {} + +local curl = require("plenary.curl") +local utils = require("avante.utils") +local config = require("avante.config") +local tiktoken = require("avante.tiktoken") + +local fn = vim.fn + +local system_prompt = [[ +You are an excellent programming expert. +]] + +local base_user_prompt = [[ +Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: + +1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones. + +2. When suggesting modifications: + a. Explain why the change is necessary or beneficial. + b. Provide the exact code snippet to be replaced using this format: + +Replace lines: {{start_line}}-{{end_line}} +```{{language}} +{{suggested_code}} +``` + +3. Crucial guidelines for line numbers: + - The range {{start_line}}-{{end_line}} is INCLUSIVE. Both start_line and end_line are included in the replacement. + - Count EVERY line, including empty lines, comments, and the LAST line of the file. + - For single-line changes, use the same number for start and end lines. + - For multi-line changes, ensure the range covers ALL affected lines, from the very first to the very last. + - Include the entire block (e.g., complete function) when modifying structured code. + - Pay special attention to the start_line, ensuring it's not omitted or incorrectly set. + - Double-check that your start_line is correct, especially for changes at the beginning of the file. + - Also, be careful with the end_line, especially when it's the last line of the file. + - Double-check that your line numbers align perfectly with the original code structure. + +4. Context and verification: + - Show 1-2 unchanged lines before and after each modification as context. + - These context lines are NOT included in the replacement range. + - After each suggestion, recount the lines to verify the accuracy of your line numbers. + - Double-check that both the start_line and end_line are correct for each modification. + - Verify that your suggested changes align perfectly with the original code structure. + +5. Final check: + - Review all suggestions, ensuring each line number is correct, especially the start_line and end_line. + - Pay extra attention to the start_line of each modification, ensuring it hasn't shifted down. + - Confirm that no unrelated code is accidentally modified or deleted. + - Verify that the start_line and end_line correctly include all intended lines for replacement. + - If a modification involves the first or last line of the file, explicitly state this in your explanation. + - Perform a final alignment check to ensure your line numbers haven't shifted, especially the start_line. + - Double-check that your line numbers align perfectly with the original code structure. + - Do not show the content after these modifications. + +Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. +]] + +local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) + local api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key then + error("ANTHROPIC_API_KEY environment variable is not set") + end + + local user_prompt = base_user_prompt + + local tokens = config.get().claude.max_tokens + local headers = { + ["Content-Type"] = "application/json", + ["x-api-key"] = api_key, + ["anthropic-version"] = "2023-06-01", + ["anthropic-beta"] = "prompt-caching-2024-07-31", + } + + local code_prompt_obj = { + type = "text", + text = string.format("```%s\n%s```", code_lang, code_content), + } + + local user_prompt_obj = { + type = "text", + text = user_prompt, + } + + if tiktoken.count(code_prompt_obj.text) > 1024 then + code_prompt_obj.cache_control = { type = "ephemeral" } + end + + if tiktoken.count(user_prompt_obj.text) > 1024 then + user_prompt_obj.cache_control = { type = "ephemeral" } + end + + local body = { + model = config.get().claude.model, + system = system_prompt, + messages = { + { + role = "user", + content = { + code_prompt_obj, + { + type = "text", + text = string.format("%s", question), + }, + user_prompt_obj, + }, + }, + }, + stream = true, + temperature = config.get().claude.temperature, + max_tokens = tokens, + } + + local url = utils.trim_suffix(config.get().claude.endpoint, "/") .. "/v1/messages" + + print("Sending request to Claude API...") + + curl.post(url, { + ---@diagnostic disable-next-line: unused-local + stream = function(err, data, job) + if err then + on_complete(err) + return + end + if not data then + return + end + for line in data:gmatch("[^\r\n]+") do + if line:sub(1, 6) ~= "data: " then + return + end + vim.schedule(function() + local success, parsed = pcall(fn.json_decode, line:sub(7)) + if not success then + error("Error: failed to parse json: " .. parsed) + return + end + if parsed and parsed.type == "content_block_delta" then + on_chunk(parsed.delta.text) + elseif parsed and parsed.type == "message_stop" then + -- Stream request completed + on_complete(nil) + elseif parsed and parsed.type == "error" then + -- Stream request completed + on_complete(parsed) + end + end) + end + end, + headers = headers, + body = fn.json_encode(body), + }) +end + +local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + local api_key = os.getenv("OPENAI_API_KEY") + if not api_key then + error("OPENAI_API_KEY environment variable is not set") + end + + local user_prompt = base_user_prompt + .. "\n\nQUESTION:\n" + .. question + .. "\n\nCODE:\n" + .. "```" + .. code_lang + .. "\n" + .. code_content + .. "\n```" + + local url, headers, body + if config.get().provider == "azure" then + api_key = os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY") + if not api_key then + error("Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable.") + end + url = config.get().azure.endpoint + .. "/openai/deployments/" + .. config.get().azure.deployment + .. "/chat/completions?api-version=" + .. config.get().azure.api_version + headers = { + ["Content-Type"] = "application/json", + ["api-key"] = api_key, + } + body = { + messages = { + { role = "system", content = system_prompt }, + { role = "user", content = user_prompt }, + }, + temperature = config.get().azure.temperature, + max_tokens = config.get().azure.max_tokens, + stream = true, + } + else + url = utils.trim_suffix(config.get().openai.endpoint, "/") .. "/v1/chat/completions" + headers = { + ["Content-Type"] = "application/json", + ["Authorization"] = "Bearer " .. api_key, + } + body = { + model = config.get().openai.model, + messages = { + { role = "system", content = system_prompt }, + { role = "user", content = user_prompt }, + }, + temperature = config.get().openai.temperature, + max_tokens = config.get().openai.max_tokens, + stream = true, + } + end + + print("Sending request to " .. (config.get().provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...") + + curl.post(url, { + ---@diagnostic disable-next-line: unused-local + stream = function(err, data, job) + if err then + on_complete(err) + return + end + if not data then + return + end + for line in data:gmatch("[^\r\n]+") do + if line:sub(1, 6) ~= "data: " then + return + end + vim.schedule(function() + local piece = line:sub(7) + local success, parsed = pcall(fn.json_decode, piece) + if not success then + if piece == "[DONE]" then + on_complete(nil) + return + end + error("Error: failed to parse json: " .. parsed) + return + end + if parsed and parsed.choices and parsed.choices[1].delta.content then + on_chunk(parsed.choices[1].delta.content) + elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then + -- Stream request completed + on_complete(nil) + end + end) + end + end, + headers = headers, + body = fn.json_encode(body), + }) +end + +function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + if config.get().provider == "openai" or config.get().provider == "azure" then + call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) + elseif config.get().provider == "claude" then + call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) + end +end + +return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 035d4bb..be3f3b1 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -1,17 +1,56 @@ local M = {} -local curl = require("plenary.curl") local Path = require("plenary.path") local n = require("nui-components") local diff = require("avante.diff") -local utils = require("avante.utils") local tiktoken = require("avante.tiktoken") local config = require("avante.config") +local ai_bot = require("avante.ai_bot") local api = vim.api local fn = vim.fn local RESULT_BUF_NAME = "AVANTE_RESULT" local CONFLICT_BUF_NAME = "AVANTE_CONFLICT" +local NAMESPACE = vim.api.nvim_create_namespace("AVANTE_CODEBLOCK") + +local function parse_codeblocks(buf) + local codeblocks = {} + local in_codeblock = false + local start_line = nil + local lang = nil + + local lines = vim.api.nvim_buf_get_lines(buf, 0, -1, false) + for i, line in ipairs(lines) do + if line:match("^```") then + -- parse language + local lang_ = line:match("^```(%w+)") + if in_codeblock and not lang_ then + table.insert(codeblocks, { start_line = start_line, end_line = i - 1, lang = lang }) + in_codeblock = false + elseif lang_ then + lang = lang_ + start_line = i - 1 + in_codeblock = true + end + end + end + + return codeblocks +end + +local function is_cursor_in_codeblock(codeblocks) + local cursor_pos = vim.api.nvim_win_get_cursor(0) + local cursor_line = cursor_pos[1] - 1 -- 转换为 0-indexed 行号 + + for _, block in ipairs(codeblocks) do + if cursor_line >= block.start_line and cursor_line <= block.end_line then + return block + end + end + + return nil +end + local function create_result_buf() local buf = api.nvim_create_buf(false, true) api.nvim_set_option_value("filetype", "markdown", { buf = buf }) @@ -82,7 +121,8 @@ local function get_cur_code_buf_content() print("Error: cannot get code buffer") return {} end - return api.nvim_buf_get_lines(code_buf, 0, -1, false) + local lines = api.nvim_buf_get_lines(code_buf, 0, -1, false) + return table.concat(lines, "\n") end local function prepend_line_number(content) @@ -138,258 +178,6 @@ local function extract_code_snippets(content) return snippets end -local system_prompt = [[ -You are an excellent programming expert. -]] - -local base_user_prompt = [[ -Your primary task is to suggest code modifications with precise line number ranges. Follow these instructions meticulously: - -1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones. - -2. When suggesting modifications: - a. Explain why the change is necessary or beneficial. - b. Provide the exact code snippet to be replaced using this format: - -Replace lines: {{start_line}}-{{end_line}} -```{{language}} -{{suggested_code}} -``` - -3. Crucial guidelines for line numbers: - - The range {{start_line}}-{{end_line}} is INCLUSIVE. Both start_line and end_line are included in the replacement. - - Count EVERY line, including empty lines, comments, and the LAST line of the file. - - For single-line changes, use the same number for start and end lines. - - For multi-line changes, ensure the range covers ALL affected lines, from the very first to the very last. - - Include the entire block (e.g., complete function) when modifying structured code. - - Pay special attention to the start_line, ensuring it's not omitted or incorrectly set. - - Double-check that your start_line is correct, especially for changes at the beginning of the file. - - Also, be careful with the end_line, especially when it's the last line of the file. - - Double-check that your line numbers align perfectly with the original code structure. - -4. Context and verification: - - Show 1-2 unchanged lines before and after each modification as context. - - These context lines are NOT included in the replacement range. - - After each suggestion, recount the lines to verify the accuracy of your line numbers. - - Double-check that both the start_line and end_line are correct for each modification. - - Verify that your suggested changes align perfectly with the original code structure. - -5. Final check: - - Review all suggestions, ensuring each line number is correct, especially the start_line and end_line. - - Pay extra attention to the start_line of each modification, ensuring it hasn't shifted down. - - Confirm that no unrelated code is accidentally modified or deleted. - - Verify that the start_line and end_line correctly include all intended lines for replacement. - - If a modification involves the first or last line of the file, explicitly state this in your explanation. - - Perform a final alignment check to ensure your line numbers haven't shifted, especially the start_line. - - Double-check that your line numbers align perfectly with the original code structure. - - Do not show the content after these modifications. - -Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift. -]] - -local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) - local api_key = os.getenv("ANTHROPIC_API_KEY") - if not api_key then - error("ANTHROPIC_API_KEY environment variable is not set") - end - - local user_prompt = base_user_prompt - - local tokens = config.get().claude.max_tokens - local headers = { - ["Content-Type"] = "application/json", - ["x-api-key"] = api_key, - ["anthropic-version"] = "2023-06-01", - ["anthropic-beta"] = "prompt-caching-2024-07-31", - } - - local code_prompt_obj = { - type = "text", - text = string.format("```%s\n%s```", code_lang, code_content), - } - - local user_prompt_obj = { - type = "text", - text = user_prompt, - } - - if tiktoken.count(code_prompt_obj.text) > 1024 then - code_prompt_obj.cache_control = { type = "ephemeral" } - end - - if tiktoken.count(user_prompt_obj.text) > 1024 then - user_prompt_obj.cache_control = { type = "ephemeral" } - end - - local body = { - model = config.get().claude.model, - system = system_prompt, - messages = { - { - role = "user", - content = { - code_prompt_obj, - { - type = "text", - text = string.format("%s", question), - }, - user_prompt_obj, - }, - }, - }, - stream = true, - temperature = config.get().claude.temperature, - max_tokens = tokens, - } - - local url = utils.trim_suffix(config.get().claude.endpoint, "/") .. "/v1/messages" - - print("Sending request to Claude API...") - - curl.post(url, { - ---@diagnostic disable-next-line: unused-local - stream = function(err, data, job) - if err then - on_complete(err) - return - end - if not data then - return - end - for line in data:gmatch("[^\r\n]+") do - if line:sub(1, 6) ~= "data: " then - return - end - vim.schedule(function() - local success, parsed = pcall(fn.json_decode, line:sub(7)) - if not success then - error("Error: failed to parse json: " .. parsed) - return - end - if parsed and parsed.type == "content_block_delta" then - on_chunk(parsed.delta.text) - elseif parsed and parsed.type == "message_stop" then - -- Stream request completed - on_complete(nil) - elseif parsed and parsed.type == "error" then - -- Stream request completed - on_complete(parsed) - end - end) - end - end, - headers = headers, - body = fn.json_encode(body), - }) -end - -local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - local api_key = os.getenv("OPENAI_API_KEY") - if not api_key then - error("OPENAI_API_KEY environment variable is not set") - end - - local user_prompt = base_user_prompt - .. "\n\nQUESTION:\n" - .. question - .. "\n\nCODE:\n" - .. "```" - .. code_lang - .. "\n" - .. code_content - .. "\n```" - - local url, headers, body - if config.get().provider == "azure" then - api_key = os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY") - if not api_key then - error("Azure OpenAI API key is not set. Please set AZURE_OPENAI_API_KEY or OPENAI_API_KEY environment variable.") - end - url = config.get().azure.endpoint - .. "/openai/deployments/" - .. config.get().azure.deployment - .. "/chat/completions?api-version=" - .. config.get().azure.api_version - headers = { - ["Content-Type"] = "application/json", - ["api-key"] = api_key, - } - body = { - messages = { - { role = "system", content = system_prompt }, - { role = "user", content = user_prompt }, - }, - temperature = config.get().azure.temperature, - max_tokens = config.get().azure.max_tokens, - stream = true, - } - else - url = utils.trim_suffix(config.get().openai.endpoint, "/") .. "/v1/chat/completions" - headers = { - ["Content-Type"] = "application/json", - ["Authorization"] = "Bearer " .. api_key, - } - body = { - model = config.get().openai.model, - messages = { - { role = "system", content = system_prompt }, - { role = "user", content = user_prompt }, - }, - temperature = config.get().openai.temperature, - max_tokens = config.get().openai.max_tokens, - stream = true, - } - end - - print("Sending request to " .. (config.get().provider == "azure" and "Azure OpenAI" or "OpenAI") .. " API...") - - curl.post(url, { - ---@diagnostic disable-next-line: unused-local - stream = function(err, data, job) - if err then - on_complete(err) - return - end - if not data then - return - end - for line in data:gmatch("[^\r\n]+") do - if line:sub(1, 6) ~= "data: " then - return - end - vim.schedule(function() - local piece = line:sub(7) - local success, parsed = pcall(fn.json_decode, piece) - if not success then - if piece == "[DONE]" then - on_complete(nil) - return - end - error("Error: failed to parse json: " .. parsed) - return - end - if parsed and parsed.choices and parsed.choices[1].delta.content then - on_chunk(parsed.choices[1].delta.content) - elseif parsed and parsed.choices and parsed.choices[1].finish_reason == "stop" then - -- Stream request completed - on_complete(nil) - end - end) - end - end, - headers = headers, - body = fn.json_encode(body), - }) -end - -local function call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - if config.get().provider == "openai" or config.get().provider == "azure" then - call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete) - elseif config.get().provider == "claude" then - call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete) - end -end - local function update_result_buf_content(content) local current_win = api.nvim_get_current_win() local result_win = fn.bufwinid(result_buf) @@ -529,6 +317,41 @@ local function get_conflict_content(content, snippets) return result end +local function get_content_between_separators() + local separator = "---" + local bufnr = vim.api.nvim_get_current_buf() + local cursor_line = vim.api.nvim_win_get_cursor(0)[1] + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local start_line, end_line + + for i = cursor_line, 1, -1 do + if lines[i] == separator then + start_line = i + 1 + break + end + end + start_line = start_line or 1 + + for i = cursor_line, #lines do + if lines[i] == separator then + end_line = i - 1 + break + end + end + end_line = end_line or #lines + + if lines[cursor_line] == separator then + if cursor_line > 1 and lines[cursor_line - 1] ~= separator then + end_line = cursor_line - 1 + elseif cursor_line < #lines and lines[cursor_line + 1] ~= separator then + start_line = cursor_line + 1 + end + end + + local content = table.concat(vim.list_slice(lines, start_line, end_line), "\n") + return content +end + local get_renderer_size_and_position = function() local renderer_width = math.ceil(vim.o.columns * 0.3) local renderer_height = vim.o.lines @@ -543,6 +366,82 @@ function M.render_sidebar() result_buf = create_result_buf() + local current_apply_extmark_id = nil + + local function show_apply_button(block) + if current_apply_extmark_id then + api.nvim_buf_del_extmark(result_buf, NAMESPACE, current_apply_extmark_id) + end + + current_apply_extmark_id = api.nvim_buf_set_extmark(result_buf, NAMESPACE, block.start_line, -1, { + virt_text = { { "[Press A to Apply these patches]", "Keyword" } }, + virt_text_pos = "right_align", + hl_group = "Keyword", + }) + end + + local function apply() + local code_buf = get_cur_code_buf() + if code_buf == nil then + error("Error: cannot get code buffer") + return + end + local content = get_cur_code_buf_content() + local response = get_content_between_separators() + local snippets = extract_code_snippets(response) + local conflict_content = get_conflict_content(content, snippets) + + vim.defer_fn(function() + api.nvim_buf_set_lines(code_buf, 0, -1, false, conflict_content) + local code_win = get_cur_code_win() + if code_win == nil then + error("Error: cannot get code window") + return + end + api.nvim_set_current_win(code_win) + api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) + diff.add_visited_buffer(code_buf) + diff.process(code_buf) + api.nvim_feedkeys("gg", "n", false) + vim.defer_fn(function() + vim.cmd("AvanteConflictNextConflict") + api.nvim_feedkeys("zz", "n", false) + end, 1000) + end, 10) + end + + local function bind_apply_key() + vim.keymap.set("n", "A", apply, { buffer = result_buf, noremap = true, silent = true }) + end + + local function unbind_apply_key() + pcall(vim.keymap.del, "n", "A", { buffer = result_buf }) + end + + local codeblocks = {} + + api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, { + buffer = result_buf, + callback = function() + local block = is_cursor_in_codeblock(codeblocks) + + if block then + show_apply_button(block) + bind_apply_key() + else + vim.api.nvim_buf_clear_namespace(result_buf, NAMESPACE, 0, -1) + unbind_apply_key() + end + end, + }) + + api.nvim_create_autocmd({ "BufEnter", "BufWritePost" }, { + buffer = result_buf, + callback = function() + codeblocks = parse_codeblocks(result_buf) + end, + }) + local renderer_width, renderer_height, renderer_position = get_renderer_size_and_position() local renderer = n.create_renderer({ @@ -596,7 +495,7 @@ function M.render_sidebar() error("Error: cannot get code buffer") return end - local content = table.concat(get_cur_code_buf_content(), "\n") + local content = get_cur_code_buf_content() local content_with_line_numbers = prepend_line_number(content) local full_response = "" @@ -604,7 +503,7 @@ function M.render_sidebar() local filetype = api.nvim_get_option_value("filetype", { buf = code_buf }) - call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) + ai_bot.call_ai_api_stream(user_input, filetype, content_with_line_numbers, function(chunk) full_response = full_response .. chunk update_result_buf_content( "## " .. timestamp .. "\n\n> " .. user_input:gsub("\n", "\n> ") .. "\n\n" .. full_response @@ -646,27 +545,6 @@ function M.render_sidebar() -- Save chat history table.insert(chat_history or {}, { timestamp = timestamp, requirement = user_input, response = full_response }) save_chat_history(chat_history) - - local snippets = extract_code_snippets(full_response) - local conflict_content = get_conflict_content(content, snippets) - - vim.defer_fn(function() - api.nvim_buf_set_lines(code_buf, 0, -1, false, conflict_content) - local code_win = get_cur_code_win() - if code_win == nil then - error("Error: cannot get code window") - return - end - api.nvim_set_current_win(code_win) - api.nvim_feedkeys(api.nvim_replace_termcodes("", true, false, true), "n", true) - diff.add_visited_buffer(code_buf) - diff.process(code_buf) - api.nvim_feedkeys("gg", "n", false) - vim.defer_fn(function() - vim.cmd("AvanteConflictNextConflict") - api.nvim_feedkeys("zz", "n", false) - end, 1000) - end, 10) end) end