diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs index 5387362..ecf7beb 100644 --- a/crates/avante-templates/src/lib.rs +++ b/crates/avante-templates/src/lib.rs @@ -43,6 +43,7 @@ struct TemplateContext { system_info: Option, model_name: Option, memory: Option, + todos: Option, } // Given the file name registered after add, the context table in Lua, resulted in a formatted @@ -70,6 +71,7 @@ fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult< system_info => context.system_info, model_name => context.model_name, memory => context.memory, + todos => context.todos, }) .map_err(LuaError::external) .unwrap()) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 731e34f..24c9a42 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -129,6 +129,72 @@ function M.summarize_memory(prev_memory, history_messages, cb) }) end +---@param user_input string +---@param cb fun(error: string | nil): nil +function M.generate_todos(user_input, cb) + local system_prompt = + [[You are an expert coding assistant. Please generate a todo list to complete the task based on the user input and pass the todo list to the add_todos tool.]] + local messages = { + { role = "user", content = user_input }, + } + + local provider = Providers[Config.provider] + local tools = { + require("avante.llm_tools.add_todos"), + } + + local history_messages = {} + cb = Utils.call_once(cb) + + M.curl({ + provider = provider, + prompt_opts = { + system_prompt = system_prompt, + messages = messages, + tools = tools, + }, + handler_opts = { + on_start = function() end, + on_chunk = function() end, + on_messages_add = function(msgs) + msgs = vim.islist(msgs) and msgs or { msgs } + for _, msg in ipairs(msgs) do + if not msg.uuid then msg.uuid = Utils.uuid() end + local idx = nil + for i, m in ipairs(history_messages) do + if m.uuid == msg.uuid then + idx = i + break + end + end + if idx ~= nil then + history_messages[idx] = msg + else + table.insert(history_messages, msg) + end + end + end, + on_stop = function(stop_opts) + if stop_opts.error ~= nil then + Utils.error(string.format("generate todos failed: %s", vim.inspect(stop_opts.error))) + return + end + if stop_opts.reason == "tool_use" then + local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages) + for _, partial_tool_use in ipairs(uncalled_tool_uses) do + if partial_tool_use.state == "generated" and partial_tool_use.name == "add_todos" then + LLMTools.process_tool_use(tools, partial_tool_use, function() end, function() cb() end, {}) + cb() + end + end + else + cb() + end + end, + }, + }) +end + ---@param opts AvanteGeneratePromptsOptions ---@return AvantePromptOptions function M.generate_prompts(opts) @@ -202,6 +268,11 @@ function M.generate_prompts(opts) memory = opts.memory, } + if opts.get_todos then + local todos = opts.get_todos() + if todos and #todos > 0 then template_opts.todos = vim.json.encode(todos) end + end + local system_prompt if opts.prompt_opts and opts.prompt_opts.system_prompt then system_prompt = opts.prompt_opts.system_prompt @@ -492,8 +563,6 @@ function M.curl(opts) local retry_after = 10 if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end if result.status == 429 then - Utils.debug("result", result) - handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after }) return end @@ -709,74 +778,69 @@ function M._stream(opts) end return opts.on_stop({ reason = "cancelled" }) end - local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[] - local tool_result_seen = {} local history_messages = opts.get_history_messages and opts.get_history_messages({ all = true }) or {} - for idx = #history_messages, 1, -1 do - local message = history_messages[idx] - local content = message.message.content - if type(content) ~= "table" or #content == 0 then goto continue end - local is_break = false - for _, item in ipairs(content) do - if item.type == "tool_use" then - if not tool_result_seen[item.id] then - local partial_tool_use = { - name = item.name, - id = item.id, - input = item.input, - state = message.state, - } - table.insert(partial_tool_use_list, 1, partial_tool_use) - else - is_break = true - break - end - end - if item.type == "tool_result" then tool_result_seen[item.tool_use_id] = true end - end - if is_break then break end - ::continue:: - end + local uncalled_tool_uses = Utils.get_uncalled_tool_uses(history_messages) if stop_opts.reason == "complete" and Config.mode == "agentic" then - if #partial_tool_use_list == 0 then - local completed_attempt_completion_tool_use = nil - for idx = #history_messages, 1, -1 do - local message = history_messages[idx] - if message.is_user_submission then break end - if not Utils.is_tool_use_message(message) then goto continue end - if message.message.content[1].name ~= "attempt_completion" then break end - completed_attempt_completion_tool_use = message - if message then break end - ::continue:: - end - local user_reminder_count = opts.session_ctx.user_reminder_count or 0 - if not completed_attempt_completion_tool_use and opts.on_messages_add and user_reminder_count < 3 then - opts.session_ctx.user_reminder_count = user_reminder_count + 1 - local message = HistoryMessage:new({ + local completed_attempt_completion_tool_use = nil + for idx = #history_messages, 1, -1 do + local message = history_messages[idx] + if message.is_user_submission then break end + if not Utils.is_tool_use_message(message) then goto continue end + if message.message.content[1].name ~= "attempt_completion" then break end + completed_attempt_completion_tool_use = message + if message then break end + ::continue:: + end + local unfinished_todos = {} + if opts.get_todos then + local todos = opts.get_todos() + unfinished_todos = vim.tbl_filter( + function(todo) return todo.status ~= "done" or todo.status ~= "cancelled" end, + todos + ) + end + local user_reminder_count = opts.session_ctx.user_reminder_count or 0 + if + not completed_attempt_completion_tool_use + and opts.on_messages_add + and (user_reminder_count < 3 or #unfinished_todos > 0) + then + opts.session_ctx.user_reminder_count = user_reminder_count + 1 + Utils.debug("user reminder count", user_reminder_count) + local message + if #unfinished_todos > 0 then + message = HistoryMessage:new({ + role = "user", + content = "You should use tool calls to answer the question, for example, use update_todo_status if the task step is done or cancelled.", + }, { + visible = false, + }) + else + message = HistoryMessage:new({ role = "user", content = "You should use tool calls to answer the question, for example, use attempt_completion if the job is done.", }, { visible = false, }) - opts.on_messages_add({ message }) - local new_opts = vim.tbl_deep_extend("force", opts, { - history_messages = opts.get_history_messages(), - }) - if provider.get_rate_limit_sleep_time then - local sleep_time = provider:get_rate_limit_sleep_time(resp_headers) - if sleep_time and sleep_time > 0 then - Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...") - vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000) - return - end - end - M._stream(new_opts) - return end + opts.on_messages_add({ message }) + local new_opts = vim.tbl_deep_extend("force", opts, { + history_messages = opts.get_history_messages(), + }) + if provider.get_rate_limit_sleep_time then + local sleep_time = provider:get_rate_limit_sleep_time(resp_headers) + if sleep_time and sleep_time > 0 then + Utils.info("Rate limit reached. Sleeping for " .. sleep_time .. " seconds ...") + vim.defer_fn(function() M._stream(new_opts) end, sleep_time * 1000) + return + end + end + M._stream(new_opts) + return end end if stop_opts.reason == "tool_use" then - return handle_next_tool_use(partial_tool_use_list, 1, {}, stop_opts.streaming_tool_use) + return handle_next_tool_use(uncalled_tool_uses, 1, {}, stop_opts.streaming_tool_use) end if stop_opts.reason == "rate_limit" then local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*" diff --git a/lua/avante/llm_tools/add_todos.lua b/lua/avante/llm_tools/add_todos.lua new file mode 100644 index 0000000..280b520 --- /dev/null +++ b/lua/avante/llm_tools/add_todos.lua @@ -0,0 +1,81 @@ +local Base = require("avante.llm_tools.base") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "add_todos" + +M.description = "Add TODOs to the current task" + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "todos", + description = "The TODOs to add", + type = "array", + items = { + name = "items", + type = "object", + fields = { + { + name = "id", + description = "The ID of the TODO", + type = "string", + }, + { + name = "content", + description = "The content of the TODO", + type = "string", + }, + { + name = "status", + description = "The status of the TODO", + type = "string", + choices = { "todo", "doing", "done", "cancelled" }, + }, + { + name = "priority", + description = "The priority of the TODO", + type = "string", + choices = { "low", "medium", "high" }, + }, + }, + }, + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "Whether the TODOs were added successfully", + type = "boolean", + }, + { + name = "error", + description = "Error message if the TODOs could not be updated", + type = "string", + optional = true, + }, +} + +M.on_render = function() return {} end + +---@type AvanteLLMToolFunc<{ todos: avante.TODO[] }> +function M.func(opts, on_log, on_complete, session_ctx) + local sidebar = require("avante").get() + if not sidebar then return false, "Avante sidebar not found" end + local todos = opts.todos + if not todos or #todos == 0 then return false, "No todos provided" end + sidebar:update_todos(todos) + if on_complete then + on_complete(true, nil) + return nil, nil + end + return true, nil +end + +return M diff --git a/lua/avante/llm_tools/init.lua b/lua/avante/llm_tools/init.lua index 35b508c..a15347a 100644 --- a/lua/avante/llm_tools/init.lua +++ b/lua/avante/llm_tools/init.lua @@ -762,6 +762,8 @@ M._tools = { require("avante.llm_tools.ls"), require("avante.llm_tools.grep"), require("avante.llm_tools.delete_tool_use_messages"), + require("avante.llm_tools.add_todos"), + require("avante.llm_tools.update_todo_status"), { name = "read_file_toplevel_symbols", description = "Read the top-level symbols of a file in current project scope", diff --git a/lua/avante/llm_tools/update_todo_status.lua b/lua/avante/llm_tools/update_todo_status.lua new file mode 100644 index 0000000..d6b49c2 --- /dev/null +++ b/lua/avante/llm_tools/update_todo_status.lua @@ -0,0 +1,65 @@ +local Base = require("avante.llm_tools.base") + +---@class AvanteLLMTool +local M = setmetatable({}, Base) + +M.name = "update_todo_status" + +M.description = "Update the status of TODO" + +---@type AvanteLLMToolParam +M.param = { + type = "table", + fields = { + { + name = "id", + description = "The ID of the TODO to update", + type = "string", + }, + { + name = "status", + description = "The status of the TODO to update", + type = "string", + choices = { "todo", "doing", "done", "cancelled" }, + }, + }, +} + +---@type AvanteLLMToolReturn[] +M.returns = { + { + name = "success", + description = "Whether the TODO was updated successfully", + type = "boolean", + }, + { + name = "error", + description = "Error message if the TODOs could not be updated", + type = "string", + optional = true, + }, +} + +M.on_render = function() return {} end + +---@type AvanteLLMToolFunc<{ id: string, status: string }> +function M.func(opts, on_log, on_complete, session_ctx) + local sidebar = require("avante").get() + if not sidebar then return false, "Avante sidebar not found" end + local todos = sidebar.chat_history.todos + if not todos or #todos == 0 then return false, "No todos found" end + for _, todo in ipairs(todos) do + if todo.id == opts.id then + todo.status = opts.status + break + end + end + sidebar:update_todos(todos) + if on_complete then + on_complete(true, nil) + return nil, nil + end + return true, nil +end + +return M diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 7100408..4bd2804 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -46,8 +46,9 @@ Sidebar.__index = Sidebar ---@field id integer ---@field augroup integer ---@field code avante.CodeState ----@field winids table<"result_container" | "selected_code_container" | "selected_files_container" | "input_container", integer> +---@field winids table<"result_container" | "todos_container" | "selected_code_container" | "selected_files_container" | "input_container", integer> ---@field result_container NuiSplit | nil +---@field todos_container NuiSplit | nil ---@field selected_code_container NuiSplit | nil ---@field selected_files_container NuiSplit | nil ---@field input_container NuiSplit | nil @@ -71,11 +72,13 @@ function Sidebar:new(id) code = { bufnr = 0, winid = 0, selection = nil, old_winhl = nil }, winids = { result_container = 0, + todos_container = 0, selected_files_container = 0, selected_code_container = 0, input_container = 0, }, result_container = nil, + todos_container = nil, selected_code_container = nil, selected_files_container = nil, input_container = nil, @@ -126,6 +129,7 @@ function Sidebar:reset() self.winids = { result_container = 0, selected_files_container = 0, selected_code_container = 0, input_container = 0 } self.result_container = nil + self.todos_container = nil self.selected_code_container = nil self.selected_files_container = nil self.input_container = nil @@ -1330,6 +1334,7 @@ function Sidebar:on_mount(opts) callback = function(args) local closed_winid = tonumber(args.match) if closed_winid == self.winids.selected_files_container then return end + if closed_winid == self.winids.todos_container then return end if not self:is_sidebar_winid(closed_winid) then return end self:close() end, @@ -1354,6 +1359,7 @@ function Sidebar:refresh_winids() if self.winids.result_container then table.insert(winids, self.winids.result_container) end if self.winids.selected_files_container then table.insert(winids, self.winids.selected_files_container) end if self.winids.selected_code_container then table.insert(winids, self.winids.selected_code_container) end + if self.winids.todos_container then table.insert(winids, self.winids.todos_container) end if self.winids.input_container then table.insert(winids, self.winids.input_container) end local function switch_windows() @@ -1885,6 +1891,7 @@ function Sidebar:new_chat(args, cb) self.current_state = nil self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end }) if cb then cb(args) end + vim.schedule(function() self:create_todos_container() end) end function Sidebar:save_history() Path.history.save(self.code.bufnr, self.chat_history) end @@ -1898,6 +1905,15 @@ function Sidebar:delete_history_messages(uuids) Path.history.save(self.code.bufnr, self.chat_history) end +---@param todos avante.TODO[] +function Sidebar:update_todos(todos) + if self.chat_history == nil then self:reload_chat_history() end + if self.chat_history == nil then return end + self.chat_history.todos = todos + Path.history.save(self.code.bufnr, self.chat_history) + self:create_todos_container() +end + ---@param messages avante.HistoryMessage | avante.HistoryMessage[] function Sidebar:add_history_messages(messages) local history_messages = Utils.get_history_messages(self.chat_history) @@ -2017,8 +2033,7 @@ function Sidebar:create_selected_code_container() api.nvim_win_get_height(self.result_container.winid) - selected_code_size - 3 ) end - self:adjust_result_container_layout() - self:adjust_selected_files_container_layout() + self:adjust_layout() end end @@ -2604,6 +2619,9 @@ function Sidebar:create_input_container() Path.history.save(self.code.bufnr, self.chat_history) end + local history_messages = Utils.get_history_messages(self.chat_history) + local is_first_request = #history_messages == 0 + if request and request ~= "" then self:add_history_messages({ HistoryMessage:new({ @@ -2627,6 +2645,10 @@ function Sidebar:create_input_container() on_messages_add = on_messages_add, on_state_change = on_state_change, get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end, + get_todos = function() + local history = Path.history.load(self.code.bufnr) + return history and history.todos or {} + end, session_ctx = {}, }) @@ -2867,13 +2889,23 @@ function Sidebar:get_selected_files_size() return selected_files_size end +function Sidebar:get_todos_container_height() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then return 0 end + return 3 +end + function Sidebar:get_result_container_height() + local todos_height = self:get_todos_container_height() local selected_code_size = self:get_selected_code_size() local selected_files_size = self:get_selected_files_size() if self:get_layout() == "horizontal" then return math.floor(Config.windows.height / 100 * vim.o.lines) end - return math.max(1, api.nvim_win_get_height(self.code.winid) - selected_files_size - selected_code_size - 3 - 8) + return math.max( + 1, + api.nvim_win_get_height(self.code.winid) - selected_files_size - selected_code_size - todos_height - 3 - 8 + ) end function Sidebar:get_result_container_width() @@ -2948,6 +2980,8 @@ function Sidebar:render(opts) self:create_selected_code_container() + self:create_todos_container() + self:on_mount(opts) self:setup_colors() @@ -2967,6 +3001,13 @@ function Sidebar:adjust_selected_files_container_layout() api.nvim_win_set_height(self.selected_files_container.winid, win_height) end +function Sidebar:adjust_todos_container_layout() + if not Utils.is_valid_container(self.todos_container, true) then return end + + local win_height = self:get_todos_container_height() + api.nvim_win_set_height(self.todos_container.winid, win_height) +end + function Sidebar:create_selected_files_container() if self.selected_files_container then self.selected_files_container:unmount() end @@ -2995,7 +3036,6 @@ function Sidebar:create_selected_files_container() }), position = "top", size = { - width = "40%", height = 2, }, }) @@ -3057,7 +3097,7 @@ function Sidebar:create_selected_files_container() Highlights.SUBTITLE, Highlights.REVERSED_SUBTITLE ) - self:adjust_result_container_layout() + self:adjust_layout() end self.file_selector:on("update", render) @@ -3095,4 +3135,76 @@ function Sidebar:create_selected_files_container() render() end +function Sidebar:create_todos_container() + local history = Path.history.load(self.code.bufnr) + if not history or not history.todos or #history.todos == 0 then + if self.todos_container then self.todos_container:unmount() end + self.todos_container = nil + self:adjust_layout() + self:refresh_winids() + return + end + if not self.todos_container then + self.todos_container = Split({ + enter = false, + relative = { + type = "win", + winid = self.input_container.winid, + }, + buf_options = vim.tbl_deep_extend("force", buf_options, { + modifiable = false, + swapfile = false, + buftype = "nofile", + bufhidden = "wipe", + filetype = "AvanteTodos", + }), + win_options = vim.tbl_deep_extend("force", base_win_options, { + fillchars = Config.windows.fillchars, + }), + position = "top", + size = { + height = 3, + }, + }) + self.todos_container:mount() + end + local done_count = 0 + local total_count = #history.todos + local focused_idx = 1 + local todos_content_lines = {} + for idx, todo in ipairs(history.todos) do + local status_content = "[ ]" + if todo.status == "done" then + done_count = done_count + 1 + status_content = "[x]" + end + if todo.status == "doing" then status_content = "[-]" end + local line = string.format("%s %d. %s", status_content, idx, todo.content) + if todo.status == "cancelled" then line = "~~" .. line .. "~~" end + if todo.status ~= "todo" then focused_idx = idx + 1 end + table.insert(todos_content_lines, line) + end + if focused_idx > #todos_content_lines then focused_idx = #todos_content_lines end + local todos_buf = api.nvim_win_get_buf(self.todos_container.winid) + Utils.unlock_buf(todos_buf) + api.nvim_buf_set_lines(todos_buf, 0, -1, false, todos_content_lines) + api.nvim_win_set_cursor(self.todos_container.winid, { focused_idx, 0 }) + Utils.lock_buf(todos_buf) + self:render_header( + self.todos_container.winid, + todos_buf, + Utils.icon(" ") .. "Todos" .. " (" .. done_count .. "/" .. total_count .. ")", + Highlights.SUBTITLE, + Highlights.REVERSED_SUBTITLE + ) + self:adjust_layout() + self:refresh_winids() +end + +function Sidebar:adjust_layout() + self:adjust_result_container_layout() + self:adjust_todos_container_layout() + self:adjust_selected_files_container_layout() +end + return Sidebar diff --git a/lua/avante/templates/_environments.avanterules b/lua/avante/templates/_environments.avanterules new file mode 100644 index 0000000..3fdae84 --- /dev/null +++ b/lua/avante/templates/_environments.avanterules @@ -0,0 +1,7 @@ +{% if system_info -%} +==== + +SYSTEM INFORMATION + +{{system_info}} +{%- endif %} diff --git a/lua/avante/templates/_task-guidelines.avanterules b/lua/avante/templates/_task-guidelines.avanterules new file mode 100644 index 0000000..ec2735f --- /dev/null +++ b/lua/avante/templates/_task-guidelines.avanterules @@ -0,0 +1,67 @@ +{% if todos -%} +==== + +# Task TODOs + +{{todos}} + +==== +{%- endif %} + +# Task Management +You have access to the add_todos and update_todo_status tools to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress. +These tools are also EXTREMELY helpful for planning tasks, and for breaking down larger complex tasks into smaller steps. If you do not use this tool when planning, you may forget to do important tasks - and that is unacceptable. + +It is critical that you mark todos as completed as soon as you are done with a task. Do not batch up multiple tasks before marking them as completed. + +Examples: + + +user: Run the build and fix any type errors +assistant: I'm going to use the add_todos tool to write the following items to the todo list: +- Run the build +- Fix any type errors + +I'm now going to run the build using Bash. + +Looks like I found 10 type errors. I'm going to use the add_todos tool to write 10 items to the todo list. + +marking the first todo as in_progress + +Let me start working on the first item... + +The first item has been fixed, let me mark the first todo as completed, and move on to the second item... +.. +.. + +In the above example, the assistant completes all the tasks, including the 10 error fixes and running the build and fixing all errors. + + +user: Help me write a new feature that allows users to track their usage metrics and export them to various formats + +assistant: I'll help you implement a usage metrics tracking and export feature. Let me first use the add_todos tool to plan this task. +Adding the following todos to the todo list: +1. Research existing metrics tracking in the codebase +2. Design the metrics collection system +3. Implement core metrics tracking functionality +4. Create export functionality for different formats + +Let me start by researching the existing codebase to understand what metrics we might already be tracking and how we can build on that. + +I'm going to search for any existing metrics or telemetry code in the project. + +I've found some existing telemetry code. Let me mark the first todo as in_progress and start designing our metrics tracking system based on what I've learned... + +[Assistant continues implementing the feature step by step, marking todos as in_progress and completed as they go] + + + +# Doing tasks +The user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended: +- Use the add_todos tool to plan the task if required +- Use the update_todo_status tool to mark todos as doing, done, or cancelled +- Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially. +- Implement the solution using all tools available to you +- Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach. +- VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) with Bash if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to CLAUDE.md so that you will know to run it next time. +NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive. diff --git a/lua/avante/templates/agentic.avanterules b/lua/avante/templates/agentic.avanterules index 51efb25..739b788 100644 --- a/lua/avante/templates/agentic.avanterules +++ b/lua/avante/templates/agentic.avanterules @@ -1,10 +1,15 @@ {% extends "base.avanterules" %} + {% block extra_prompt %} +{% include "_task-guidelines.avanterules" %} + ==== RULES +- Strictly follow the TODOs step by step to complete the task without stopping, and after completing each step, use the update_todo_status tool to update the status of the TODOs. + - NEVER reply the updated code. - Always reply to the user in the same language they are using. @@ -19,6 +24,8 @@ RULES - NEVER end attempt_completion result with a question or request to engage in further conversation! Formulate the end of your result in a way that is final and does not require further input from the user. +- Ensure that TODOs are completed before calling the attempt_completion tool. + ==== OBJECTIVE diff --git a/lua/avante/templates/base.avanterules b/lua/avante/templates/base.avanterules index 3a8a144..de5c3ce 100644 --- a/lua/avante/templates/base.avanterules +++ b/lua/avante/templates/base.avanterules @@ -6,15 +6,13 @@ Make sure code comments are in English when generating them. Memory is crucial, you must follow the instructions in ! -{% include "_tools-guidelines.avanterules" %} +{% include "_environments.avanterules" %} -{% if system_info -%} ==== -SYSTEM INFORMATION +{% include "_tools-guidelines.avanterules" %} -{{system_info}} -{%- endif %} +==== {% block extra_prompt %} {% endblock %} diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 863a394..db522c7 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -86,6 +86,12 @@ vim.g.avante_login = vim.g.avante_login ---@field role "user" | "assistant" ---@field content AvanteLLMMessageContent +---@class avante.TODO +---@field id string +---@field content string +---@field status "todo" | "doing" | "done" | "cancelled" +---@field priority "low" | "medium" | "high" + ---@class avante.HistoryMessage ---@field message AvanteLLMMessage ---@field timestamp string @@ -342,6 +348,7 @@ vim.g.avante_login = vim.g.avante_login ---@field selected_filepaths string[] | nil ---@field diagnostics string | nil ---@field history_messages avante.HistoryMessage[] | nil +---@field get_todos? fun(): avante.TODO[] ---@field memory string | nil --- ---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions @@ -404,8 +411,10 @@ vim.g.avante_login = vim.g.avante_login ---@field name string ---@field description? string ---@field get_description? fun(): string ----@field type 'string' | 'integer' | 'boolean' | 'object' +---@field type 'string' | 'integer' | 'boolean' | 'object' | 'array' ---@field fields? AvanteLLMToolParamField[] +---@field items? AvanteLLMToolParamField +---@field choices? string[] ---@field optional? boolean ---@class AvanteLLMToolReturn @@ -431,6 +440,7 @@ vim.g.avante_login = vim.g.avante_login ---@field timestamp string ---@field messages avante.HistoryMessage[] | nil ---@field entries avante.ChatHistoryEntry[] | nil +---@field todos avante.TODO[] | nil ---@field memory avante.ChatMemory | nil ---@field filename string ---@field system_prompt string | nil diff --git a/lua/avante/ui/confirm.lua b/lua/avante/ui/confirm.lua index cf92fab..a2cedb7 100644 --- a/lua/avante/ui/confirm.lua +++ b/lua/avante/ui/confirm.lua @@ -33,6 +33,7 @@ function M:new(message, callback, opts) end function M:open() + if self._popup then return end self._prev_winid = vim.api.nvim_get_current_win() local message = self.message local callback = self.callback diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 0df5c44..60cd1c8 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1293,11 +1293,20 @@ function M.llm_tool_param_fields_to_json_schema(fields) properties = properties_, required = required_, } + elseif field.type == "array" and field.items then + local properties_ = M.llm_tool_param_fields_to_json_schema({ field.items }) + local _, obj = next(properties_) + properties[field.name] = { + type = field.type, + description = field.get_description and field.get_description() or field.description, + items = obj, + } else properties[field.name] = { type = field.type, description = field.get_description and field.get_description() or field.description, } + if field.choices then properties[field.name].enum = field.choices end end if not field.optional then table.insert(required, field.name) end end @@ -1723,4 +1732,46 @@ function M.tbl_override(value, override) return vim.tbl_extend("force", value, override) end +---@param history_messages avante.HistoryMessage[] +---@return AvantePartialLLMToolUse[] +function M.get_uncalled_tool_uses(history_messages) + local partial_tool_use_list = {} ---@type AvantePartialLLMToolUse[] + local tool_result_seen = {} + for idx = #history_messages, 1, -1 do + local message = history_messages[idx] + local content = message.message.content + if type(content) ~= "table" or #content == 0 then goto continue end + local is_break = false + for _, item in ipairs(content) do + if item.type == "tool_use" then + if not tool_result_seen[item.id] then + local partial_tool_use = { + name = item.name, + id = item.id, + input = item.input, + state = message.state, + } + table.insert(partial_tool_use_list, 1, partial_tool_use) + else + is_break = true + break + end + end + if item.type == "tool_result" then tool_result_seen[item.tool_use_id] = true end + end + if is_break then break end + ::continue:: + end + return partial_tool_use_list +end + +function M.call_once(func) + local called = false + return function(...) + if called then return end + called = true + return func(...) + end +end + return M