diff --git a/README.md b/README.md index ce4fe1e..90b378c 100644 --- a/README.md +++ b/README.md @@ -1150,6 +1150,82 @@ providers = { } ``` +## ACP Support + +Avante.nvim now supports the [Agent Client Protocol (ACP)](https://agentclientprotocol.com/overview/introduction), enabling seamless integration with AI agents that follow this standardized communication protocol. ACP provides a unified way for AI agents to interact with development environments, offering enhanced capabilities for code editing, file operations, and tool execution. + +### What is ACP? + +The Agent Client Protocol (ACP) is a standardized protocol that enables AI agents to communicate with development tools and environments. It provides: + +- **Standardized Communication**: A unified JSON-RPC based protocol for agent-client interactions +- **Tool Integration**: Support for various development tools like file operations, code execution, and search +- **Session Management**: Persistent sessions that maintain context across interactions +- **Permission System**: Granular control over what agents can access and modify + +### Enabling ACP + +To use ACP-compatible agents with Avante.nvim, you need to configure an ACP provider. Here are the currently supported ACP agents: + +#### Gemini CLI with ACP +```lua +require('avante').setup({ + provider = "gemini-cli", + -- other configuration options... +}) +``` + +#### Claude Code with ACP +```lua +require('avante').setup({ + provider = "claude-code", + -- other configuration options... +}) +``` + +### ACP Configuration + +ACP providers are configured in the `acp_providers` section of your configuration: + +```lua +require('avante').setup({ + acp_providers = { + ["gemini-cli"] = { + command = "gemini", + args = { "--experimental-acp" }, + env = { + NODE_NO_WARNINGS = "1", + GEMINI_API_KEY = os.getenv("GEMINI_API_KEY"), + }, + }, + ["claude-code"] = { + command = "npx", + args = { "acp-claude-code" }, + env = { + NODE_NO_WARNINGS = "1", + ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY"), + }, + }, + }, +}) +``` + +### Prerequisites + +Before using ACP agents, ensure you have the required tools installed: + +- **For Gemini CLI**: Install the `gemini` CLI tool and set your `GEMINI_API_KEY` +- **For Claude Code**: Install the `acp-claude-code` package via npm and set your `ANTHROPIC_API_KEY` + +### ACP vs Traditional Providers + +ACP providers offer several advantages over traditional API-based providers: + +- **Enhanced Tool Access**: Agents can directly interact with your file system, run commands, and access development tools +- **Persistent Context**: Sessions maintain state across multiple interactions +- **Fine-grained Permissions**: Control exactly what agents can access and modify +- **Standardized Protocol**: Compatible with any ACP-compliant agent + ## Custom providers Avante provides a set of default providers, but users can also create their own providers. diff --git a/lua/avante/config.lua b/lua/avante/config.lua index a1d9c7b..eac2f34 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -237,6 +237,28 @@ M._defaults = { }, }, }, + acp_providers = { + ["gemini-cli"] = { + command = "gemini", + args = { "--experimental-acp" }, + env = { + NODE_NO_WARNINGS = "1", + GEMINI_API_KEY = os.getenv("GEMINI_API_KEY"), + }, + auth_method = "gemini-api-key", + }, + ["claude-code"] = { + command = "npx", + args = { "acp-claude-code" }, + env = { + NODE_NO_WARNINGS = "1", + ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY"), + ANTHROPIC_BASE_URL = os.getenv("ANTHROPIC_BASE_URL"), + ACP_PATH_TO_CLAUDE_CODE_EXECUTABLE = vim.fn.exepath("claude"), + ACP_PERMISSION_MODE = "bypassPermissions", + }, + }, + }, ---To add support for custom provider, follow the format below ---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details ---@type {[string]: AvanteProvider} @@ -466,7 +488,7 @@ M._defaults = { use_cwd_as_project_root = false, auto_focus_on_diff_view = false, ---@type boolean | string[] -- true: auto-approve all tools, false: normal prompts, string[]: auto-approve specific tools by name - auto_approve_tool_permissions = false, -- Default: show permission prompts for all tools + auto_approve_tool_permissions = true, -- Default: show permission prompts for all tools auto_check_diagnostics = true, enable_fastapply = false, }, @@ -774,6 +796,7 @@ end local function apply_model_selection(config, model_name, provider_name) local provider_list = config.providers or {} local current_provider_name = config.provider + if config.acp_providers[current_provider_name] then return end local target_provider_name = provider_name or current_provider_name local target_provider = provider_list[target_provider_name] diff --git a/lua/avante/history/render.lua b/lua/avante/history/render.lua index 1a63ae7..f4dd758 100644 --- a/lua/avante/history/render.lua +++ b/lua/avante/history/render.lua @@ -1,17 +1,26 @@ local Helpers = require("avante.history.helpers") local Line = require("avante.ui.line") local Utils = require("avante.utils") +local Highlights = require("avante.highlights") local M = {} +---@diagnostic disable-next-line: deprecated +local islist = vim.islist or vim.tbl_islist + ---Converts text into format suitable for UI ---@param text string +---@param decoration string | nil ---@return avante.ui.Line[] -local function text_to_lines(text) +local function text_to_lines(text, decoration) local text_lines = vim.split(text, "\n") local lines = {} for _, text_line in ipairs(text_lines) do - table.insert(lines, Line:new({ { text_line } })) + if decoration then + table.insert(lines, Line:new({ { decoration }, { text_line } })) + else + table.insert(lines, Line:new({ { text_line } })) + end end return lines end @@ -22,6 +31,10 @@ end local function thinking_to_lines(item) local text = item.thinking or item.data or "" local text_lines = vim.split(text, "\n") + --- trime suffix empty lines + while #text_lines > 0 and text_lines[#text_lines] == "" do + table.remove(text_lines, #text_lines) + end local ui_lines = {} table.insert(ui_lines, Line:new({ { Utils.icon("🤔 ") .. "Thought content:" } })) table.insert(ui_lines, Line:new({ { "" } })) @@ -35,7 +48,7 @@ end ---@param tool_name string ---@param logs string[] ---@return avante.ui.Line[] -local function tool_logs_to_lines(tool_name, logs) +function M.tool_logs_to_lines(tool_name, logs) local ui_lines = {} local num_logs = #logs @@ -44,7 +57,7 @@ local function tool_logs_to_lines(tool_name, logs) local num_lines = #log_lines for line_idx = 1, num_lines do - local decoration = (log_idx == num_logs and line_idx == num_lines) and "╰─ " or "│ " + local decoration = "│ " table.insert(ui_lines, Line:new({ { decoration }, { " " .. log_lines[line_idx] } })) end end @@ -57,14 +70,201 @@ local STATE_TO_HL = { succeeded = "AvanteStateSpinnerSucceeded", } +function M.get_diff_lines(old_str, new_str, decoration, truncate) + local lines = {} + local line_count = 0 + local old_lines = vim.split(old_str, "\n") + local new_lines = vim.split(new_str, "\n") + ---@diagnostic disable-next-line: assign-type-mismatch, missing-fields + local patch = vim.diff(old_str, new_str, { ---@type integer[][] + algorithm = "histogram", + result_type = "indices", + ctxlen = vim.o.scrolloff, + }) + local prev_start_a = 0 + for idx, hunk in ipairs(patch) do + if truncate and line_count > 10 then + table.insert( + lines, + Line:new({ + { decoration }, + { string.format("... (Result truncated, remaining %d hunks not shown)", #patch - idx + 1) }, + }) + ) + break + end + local start_a, count_a, start_b, count_b = unpack(hunk) + local no_change_lines = vim.list_slice(old_lines, prev_start_a, start_a - 1) + local last_tree_no_change_lines = vim.list_slice(no_change_lines, #no_change_lines - 3) + if #no_change_lines > 3 then table.insert(lines, Line:new({ { decoration }, { "..." } })) end + for _, line in ipairs(last_tree_no_change_lines) do + line_count = line_count + 1 + table.insert(lines, Line:new({ { decoration }, { line } })) + end + prev_start_a = start_a + count_a + if count_a > 0 then + local delete_lines = vim.list_slice(old_lines, start_a, start_a + count_a - 1) + for _, line in ipairs(delete_lines) do + line_count = line_count + 1 + table.insert(lines, Line:new({ { decoration }, { line, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } })) + end + end + if count_b > 0 then + local create_lines = vim.list_slice(new_lines, start_b, start_b + count_b - 1) + for _, line in ipairs(create_lines) do + line_count = line_count + 1 + table.insert(lines, Line:new({ { decoration }, { line, Highlights.INCOMING } })) + end + end + end + if prev_start_a < #old_lines then + -- Append remaining old_lines + local no_change_lines = vim.list_slice(old_lines, prev_start_a, #old_lines) + local first_tree_no_change_lines = vim.list_slice(no_change_lines, 1, 3) + for _, line in ipairs(first_tree_no_change_lines) do + line_count = line_count + 1 + table.insert(lines, Line:new({ { decoration }, { line } })) + end + if #no_change_lines > 3 then table.insert(lines, Line:new({ { decoration }, { "..." } })) end + end + return lines +end + +---@param content any +---@param decoration string | nil +---@param truncate boolean | nil +function M.get_content_lines(content, decoration, truncate) + local lines = {} + local content_obj = content + if type(content) == "string" then + local ok, content_obj_ = pcall(vim.json.decode, content) + if ok then content_obj = content_obj_ end + end + if type(content_obj) == "table" then + if islist(content_obj) then + local line_count = 0 + local all_lines = {} + for _, content_item in ipairs(content_obj) do + if type(content_item) == "string" then + local lines_ = text_to_lines(content_item, decoration) + all_lines = vim.list_extend(all_lines, lines_) + end + end + for idx, line in ipairs(all_lines) do + if truncate and line_count > 3 then + table.insert( + lines, + Line:new({ + { decoration }, + { string.format("... (Result truncated, remaining %d lines not shown)", #all_lines - idx + 1) }, + }) + ) + break + end + line_count = line_count + 1 + table.insert(lines, line) + end + end + if type(content_obj.content) == "string" then + local line_count = 0 + local lines_ = text_to_lines(content_obj.content, decoration) + for idx, line in ipairs(lines_) do + if truncate and line_count > 3 then + table.insert( + lines, + Line:new({ + { decoration }, + { string.format("... (Result truncated, remaining %d lines not shown)", #lines_ - idx + 1) }, + }) + ) + break + end + line_count = line_count + 1 + table.insert(lines, line) + end + end + end + if type(content_obj) == "string" then + local lines_ = text_to_lines(content_obj, decoration) + local line_count = 0 + for idx, line in ipairs(lines_) do + if truncate and line_count > 3 then + table.insert( + lines, + Line:new({ + { decoration }, + { string.format("... (Result truncated, remaining %d lines not shown)", #lines_ - idx + 1) }, + }) + ) + break + end + line_count = line_count + 1 + table.insert(lines, line) + end + end + if islist(content) then + for _, content_item in ipairs(content) do + local line_count = 0 + if content_item.type == "content" then + if content_item.content.type == "text" then + local lines_ = text_to_lines(content_item.content.text, decoration) + for idx, line in ipairs(lines_) do + if truncate and line_count > 3 then + table.insert( + lines, + Line:new({ + { decoration }, + { string.format("... (Result truncated, remaining %d lines not shown)", #lines_ - idx + 1) }, + }) + ) + break + end + line_count = line_count + 1 + table.insert(lines, line) + end + end + elseif content_item.type == "diff" then + table.insert(lines, Line:new({ { decoration }, { "Path: " .. content_item.path } })) + local lines_ = M.get_diff_lines(content_item.oldText, content_item.newText, decoration, truncate) + lines = vim.list_extend(lines, lines_) + end + end + end + return lines +end + ---Converts a tool invocation into format suitable for UI ---@param item AvanteLLMMessageContentItem +---@param message avante.HistoryMessage ---@param messages avante.HistoryMessage[] ----@param logs string[]|nil ---@return avante.ui.Line[] -local function tool_to_lines(item, messages, logs) +local function tool_to_lines(item, message, messages) + -- local logs = message.tool_use_logs local lines = {} + local tool_name = item.name + + local rest_input_text_lines = {} + + if message.displayed_tool_name then + tool_name = message.displayed_tool_name + else + if item.input and type(item.input) == "table" then + local param + if type(item.input.path) == "string" then param = item.input.path end + if type(item.input.rel_path) == "string" then param = item.input.rel_path end + if type(item.input.filepath) == "string" then param = item.input.filepath end + if type(item.input.query) == "string" then param = item.input.query end + if type(item.input.pattern) == "string" then param = item.input.pattern end + if type(item.input.command) == "string" then + param = item.input.command + local pieces = vim.split(param, "\n") + if #pieces > 1 then param = pieces[1] .. "..." end + end + if param then tool_name = item.name .. "(" .. param .. ")" end + end + end + local result = Helpers.get_tool_result(item.id, messages) local state if not result then @@ -78,11 +278,47 @@ local function tool_to_lines(item, messages, logs) lines, Line:new({ { "╭─ " }, - { " " .. item.name .. " ", STATE_TO_HL[state] }, + { " " .. tool_name .. " ", STATE_TO_HL[state] }, { " " .. state }, }) ) - if logs then vim.list_extend(lines, tool_logs_to_lines(item.name, logs)) end + -- if logs then vim.list_extend(lines, tool_logs_to_lines(item.name, logs)) end + local decoration = "│ " + if rest_input_text_lines and #rest_input_text_lines > 0 then + local lines_ = text_to_lines(table.concat(rest_input_text_lines, "\n"), decoration) + local line_count = 0 + for idx, line in ipairs(lines_) do + if line_count > 3 then + table.insert( + lines, + Line:new({ + { decoration }, + { string.format("... (Input truncated, remaining %d lines not shown)", #lines_ - idx + 1) }, + }) + ) + break + end + line_count = line_count + 1 + table.insert(lines, line) + end + table.insert(lines, Line:new({ { decoration }, { "" } })) + end + if item.input and type(item.input) == "table" then + if type(item.input.old_str) == "string" and type(item.input.new_str) == "string" then + local diff_lines = M.get_diff_lines(item.input.old_str, item.input.new_str, decoration, true) + vim.list_extend(lines, diff_lines) + end + end + if result and result.content then + local result_content = result.content + if result_content then + local content_lines = M.get_content_lines(result_content, decoration, true) + vim.list_extend(lines, content_lines) + end + end + if #lines <= 1 then table.insert(lines, Line:new({ { decoration }, { "completed" } })) end + local last_line = lines[#lines] + last_line.sections[1][1] = "╰─ " return lines end @@ -96,7 +332,7 @@ local function message_content_item_to_lines(item, message, messages) return text_to_lines(item) elseif type(item) == "table" then if item.type == "thinking" or item.type == "redacted_thinking" then - return thinking_to_lines(item.thinking or item.data or "") + return thinking_to_lines(item) elseif item.type == "text" then return text_to_lines(item.text) elseif item.type == "image" then @@ -116,7 +352,9 @@ local function message_content_item_to_lines(item, message, messages) end end - return tool_to_lines(item, messages, message.tool_use_logs) + local lines = tool_to_lines(item, message, messages) + if message.tool_use_log_lines then lines = vim.list_extend(lines, message.tool_use_log_lines) end + return lines end end return {} diff --git a/lua/avante/libs/acp_client.lua b/lua/avante/libs/acp_client.lua new file mode 100644 index 0000000..5eb16d6 --- /dev/null +++ b/lua/avante/libs/acp_client.lua @@ -0,0 +1,891 @@ +---@class ClientCapabilities +---@field fs FileSystemCapability + +---@class FileSystemCapability +---@field readTextFile boolean +---@field writeTextFile boolean + +---@class AgentCapabilities +---@field loadSession boolean +---@field promptCapabilities PromptCapabilities + +---@class PromptCapabilities +---@field image boolean +---@field audio boolean +---@field embeddedContext boolean + +---@class AuthMethod +---@field id string +---@field name string +---@field description string|nil + +---@class McpServer +---@field name string +---@field command string +---@field args string[] +---@field env EnvVariable[] + +---@class EnvVariable +---@field name string +---@field value string + +---@alias StopReason "end_turn" | "max_tokens" | "max_turn_requests" | "refusal" | "cancelled" + +---@alias ToolKind "read" | "edit" | "delete" | "move" | "search" | "execute" | "think" | "fetch" | "other" + +---@alias ToolCallStatus "pending" | "in_progress" | "completed" | "failed" + +---@alias PlanEntryStatus "pending" | "in_progress" | "completed" + +---@alias PlanEntryPriority "high" | "medium" | "low" + +---@class ContentBlock +---@field type "text" | "image" | "audio" | "resource_link" | "resource" +---@field annotations Annotations|nil + +---@class TextContent : ContentBlock +---@field type "text" +---@field text string + +---@class ImageContent : ContentBlock +---@field type "image" +---@field data string +---@field mimeType string +---@field uri string|nil + +---@class AudioContent : ContentBlock +---@field type "audio" +---@field data string +---@field mimeType string + +---@class ResourceLinkContent : ContentBlock +---@field type "resource_link" +---@field uri string +---@field name string +---@field description string|nil +---@field mimeType string|nil +---@field size number|nil +---@field title string|nil + +---@class ResourceContent : ContentBlock +---@field type "resource" +---@field resource EmbeddedResource + +---@class EmbeddedResource +---@field uri string +---@field text string|nil +---@field blob string|nil +---@field mimeType string|nil + +---@class Annotations +---@field audience any[]|nil +---@field lastModified string|nil +---@field priority number|nil + +---@class ToolCall +---@field toolCallId string +---@field title string +---@field kind ToolKind +---@field status ToolCallStatus +---@field content ToolCallContent[] +---@field locations ToolCallLocation[] +---@field rawInput table +---@field rawOutput table + +---@class ToolCallContent +---@field type "content" | "diff" + +---@class ToolCallContentBlock : ToolCallContent +---@field type "content" +---@field content ContentBlock + +---@class ToolCallDiff : ToolCallContent +---@field type "diff" +---@field path string +---@field oldText string|nil +---@field newText string + +---@class ToolCallLocation +---@field path string +---@field line number|nil + +---@class PlanEntry +---@field content string +---@field priority PlanEntryPriority +---@field status PlanEntryStatus + +---@class Plan +---@field entries PlanEntry[] + +---@class SessionUpdate +---@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan" + +---@class UserMessageChunk : SessionUpdate +---@field sessionUpdate "user_message_chunk" +---@field content ContentBlock + +---@class AgentMessageChunk : SessionUpdate +---@field sessionUpdate "agent_message_chunk" +---@field content ContentBlock + +---@class AgentThoughtChunk : SessionUpdate +---@field sessionUpdate "agent_thought_chunk" +---@field content ContentBlock + +---@class ToolCallUpdate : SessionUpdate +---@field sessionUpdate "tool_call" | "tool_call_update" +---@field toolCallId string +---@field title string|nil +---@field kind ToolKind|nil +---@field status ToolCallStatus|nil +---@field content ToolCallContent[]|nil +---@field locations ToolCallLocation[]|nil +---@field rawInput table|nil +---@field rawOutput table|nil + +---@class PlanUpdate : SessionUpdate +---@field sessionUpdate "plan" +---@field entries PlanEntry[] + +---@class PermissionOption +---@field optionId string +---@field name string +---@field kind "allow_once" | "allow_always" | "reject_once" | "reject_always" + +---@class RequestPermissionOutcome +---@field outcome "cancelled" | "selected" +---@field optionId string|nil + +---@class ACPTransport +---@field send function +---@field start function +---@field stop function + +---@alias ACPConnectionState "disconnected" | "connecting" | "connected" | "initializing" | "ready" | "error" + +---@class ACPError +---@field code number +---@field message string +---@field data any|nil + +---@class ACPClient +---@field protocol_version number +---@field capabilities ClientCapabilities +---@field agent_capabilities AgentCapabilities|nil +---@field config ACPConfig +---@field callbacks table +local ACPClient = {} + +-- ACP Error codes +ACPClient.ERROR_CODES = { + TRANSPORT_ERROR = -32000, + PROTOCOL_ERROR = -32001, + TIMEOUT_ERROR = -32002, + AUTH_REQUIRED = -32003, + SESSION_NOT_FOUND = -32004, + PERMISSION_DENIED = -32005, + INVALID_REQUEST = -32006, +} + +---@class ACPHandlers +---@field on_session_update? function +---@field on_request_permission? function +---@field on_read_file? function +---@field on_write_file? function +---@field on_error? function + +---@class ACPConfig +---@field transport_type "stdio" | "websocket" | "tcp" +---@field command? string Command to spawn agent (for stdio) +---@field args? string[] Arguments for agent command +---@field env? table Environment variables +---@field host? string Host for tcp/websocket +---@field port? number Port for tcp/websocket +---@field timeout? number Request timeout in milliseconds +---@field reconnect? boolean Enable auto-reconnect +---@field max_reconnect_attempts? number Maximum reconnection attempts +---@field heartbeat_interval? number Heartbeat interval in milliseconds +---@field auth_method? string Authentication method +---@field handlers? ACPHandlers +---@field on_state_change? fun(new_state: ACPConnectionState, old_state: ACPConnectionState) + +---Create a new ACP client instance +---@param config ACPConfig +---@return ACPClient +function ACPClient:new(config) + local client = setmetatable({ + id_counter = 0, + protocol_version = 1, + capabilities = { + fs = { + readTextFile = true, + writeTextFile = true, + }, + }, + pending_responses = {}, + callbacks = {}, + transport = nil, + config = config or {}, + state = "disconnected", + reconnect_count = 0, + heartbeat_timer = nil, + }, { __index = self }) + + client:_setup_transport() + return client +end + +---Setup transport layer +function ACPClient:_setup_transport() + local transport_type = self.config.transport_type or "stdio" + + if transport_type == "stdio" then + self.transport = self:_create_stdio_transport() + elseif transport_type == "websocket" then + self.transport = self:_create_websocket_transport() + elseif transport_type == "tcp" then + self.transport = self:_create_tcp_transport() + else + error("Unsupported transport type: " .. transport_type) + end +end + +---Set connection state +---@param state ACPConnectionState +function ACPClient:_set_state(state) + local old_state = self.state + self.state = state + + if self.config.on_state_change then self.config.on_state_change(state, old_state) end +end + +---Create error object +---@param code number +---@param message string +---@param data any? +---@return ACPError +function ACPClient:_create_error(code, message, data) + return { + code = code, + message = message, + data = data, + } +end + +---Create stdio transport layer +function ACPClient:_create_stdio_transport() + local uv = vim.loop + local transport = { + stdin = nil, + stdout = nil, + process = nil, + } + + function transport.send(transport_self, data) + if transport_self.stdin and not transport_self.stdin:is_closing() then + transport_self.stdin:write(data .. "\n") + return true + end + return false + end + + function transport.start(transport_self, on_message) + self:_set_state("connecting") + + local stdin = uv.new_pipe(false) + local stdout = uv.new_pipe(false) + local stderr = uv.new_pipe(false) + + if not stdin or not stdout or not stderr then + self:_set_state("error") + error("Failed to create pipes for ACP agent") + end + + local args = vim.deepcopy(self.config.args or {}) + local env = self.config.env + + -- Start with system environment and override with config env + local final_env = {} + + local path = vim.fn.getenv("PATH") + if path then final_env[#final_env + 1] = "PATH=" .. path end + + if env then + for k, v in pairs(env) do + final_env[#final_env + 1] = k .. "=" .. v + end + end + + ---@diagnostic disable-next-line: missing-fields + local handle = uv.spawn(self.config.command, { + args = args, + env = final_env, + stdio = { stdin, stdout, stderr }, + }, function(code, signal) + vim.print("ACP agent exited with code " .. code .. " and signal " .. signal) + self:_set_state("disconnected") + + if transport_self.process then + transport_self.process:close() + transport_self.process = nil + end + + -- Handle auto-reconnect + if self.config.reconnect and self.reconnect_count < (self.config.max_reconnect_attempts or 3) then + self.reconnect_count = self.reconnect_count + 1 + vim.defer_fn(function() + if self.state == "disconnected" then self:connect() end + end, 2000) -- Wait 2 seconds before reconnecting + end + end) + + if not handle then + self:_set_state("error") + error("Failed to spawn ACP agent process") + end + + transport_self.process = handle + transport_self.stdin = stdin + transport_self.stdout = stdout + + self:_set_state("connected") + + -- Read stdout + local buffer = "" + stdout:read_start(function(err, data) + -- if data then vim.print("ACP stdout: " .. vim.inspect(data)) end + if err then + vim.notify("ACP stdout error: " .. err, vim.log.levels.ERROR) + self:_set_state("error") + return + end + + if data then + buffer = buffer .. data + + -- Split on newlines and process complete JSON-RPC messages + local lines = vim.split(buffer, "\n", { plain = true }) + buffer = lines[#lines] -- Keep incomplete line in buffer + + for i = 1, #lines - 1 do + local line = vim.trim(lines[i]) + if line ~= "" then + local ok, message = pcall(vim.json.decode, line) + if ok then + on_message(message) + else + vim.schedule( + function() vim.notify("Failed to parse JSON-RPC message: " .. line, vim.log.levels.WARN) end + ) + end + end + end + end + end) + + -- Read stderr for debugging + stderr:read_start(function(_, data) + if data then vim.schedule(function() vim.notify("ACP stderr: " .. data, vim.log.levels.DEBUG) end) end + end) + end + + function transport.stop(transport_self) + if transport_self.process then + transport_self.process:close() + transport_self.process = nil + end + if transport_self.stdin then + transport_self.stdin:close() + transport_self.stdin = nil + end + if transport_self.stdout then + transport_self.stdout:close() + transport_self.stdout = nil + end + self:_set_state("disconnected") + end + + return transport +end + +---Create WebSocket transport layer (placeholder) +function ACPClient:_create_websocket_transport() error("WebSocket transport not implemented yet") end + +---Create TCP transport layer (placeholder) +function ACPClient:_create_tcp_transport() error("TCP transport not implemented yet") end + +---Generate next request ID +---@return number +function ACPClient:_next_id() + self.id_counter = self.id_counter + 1 + return self.id_counter +end + +---Send JSON-RPC request +---@param method string +---@param params table? +---@param callback? fun(result: table|nil, err: ACPError|nil) +---@return table|nil result +---@return ACPError|nil err +function ACPClient:_send_request(method, params, callback) + local id = self:_next_id() + local message = { + jsonrpc = "2.0", + id = id, + method = method, + params = params or {}, + } + + if callback then self.callbacks[id] = callback end + + local data = vim.json.encode(message) .. "\n" + if not self.transport:send(data) then return nil end + + if not callback then return self:_wait_response(id) end +end + +function ACPClient:_wait_response(id) + local start_time = vim.loop.now() + local timeout = self.config.timeout or 100000 + + while vim.loop.now() - start_time < timeout do + vim.wait(10) + + if self.pending_responses[id] then + local result, err = unpack(self.pending_responses[id]) + self.pending_responses[id] = nil + return result, err + end + end + + return nil, self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for response") +end + +---Send JSON-RPC notification +---@param method string +---@param params table? +function ACPClient:_send_notification(method, params) + local message = { + jsonrpc = "2.0", + method = method, + params = params or {}, + } + + local data = vim.json.encode(message) .. "\n" + self.transport:send(data) +end + +---Send JSON-RPC result +---@param id number +---@param result table +---@return nil +function ACPClient:_send_result(id, result) + local message = { jsonrpc = "2.0", id = id, result = result } + + local data = vim.json.encode(message) .. "\n" + -- vim.print("Sending result: " .. data) + self.transport:send(data) +end + +---Send JSON-RPC error +---@param id number +---@param message string +---@param code? number +---@return nil +function ACPClient:_send_error(id, message, code) + code = code or self.ERROR_CODES.TRANSPORT_ERROR + local msg = { jsonrpc = "2.0", id = id, error = { code = code, message = message } } + + local data = vim.json.encode(msg) .. "\n" + self.transport:send(data) +end + +---Handle received message +---@param message table +function ACPClient:_handle_message(message) + -- vim.print("Received message: " .. vim.inspect(message)) + -- Check if this is a notification (has method but no id, or has both method and id for notifications) + if message.method and not message.result and not message.error then + -- This is a notification + self:_handle_notification(message.id, message.method, message.params) + elseif message.id and (message.result or message.error) then + local callback = self.callbacks[message.id] + if callback then + callback(message.result, message.error) + self.callbacks[message.id] = nil + else + self.pending_responses[message.id] = { message.result, message.error } + end + else + -- Unknown message type + vim.notify("Unknown message type: " .. vim.inspect(message), vim.log.levels.WARN) + end +end + +---Handle notification +---@param method string +---@param params table +function ACPClient:_handle_notification(message_id, method, params) + if method == "session/update" then + self:_handle_session_update(params) + elseif method == "session/request_permission" then + self:_handle_request_permission(message_id, params) + elseif method == "fs/read_text_file" then + self:_handle_read_text_file(message_id, params) + elseif method == "fs/write_text_file" then + self:_handle_write_text_file(message_id, params) + else + vim.notify("Unknown notification method: " .. method, vim.log.levels.WARN) + end +end + +---Handle session update notification +---@param params table +function ACPClient:_handle_session_update(params) + local session_id = params.sessionId + local update = params.update + + if not session_id then + vim.notify("Received session/update without sessionId", vim.log.levels.WARN) + return + end + + if not update then + vim.notify("Received session/update without update data", vim.log.levels.WARN) + return + end + + if self.config.handlers and self.config.handlers.on_session_update then + self.config.handlers.on_session_update(update) + end +end + +---Handle permission request notification +---@param message_id number +---@param params table +function ACPClient:_handle_request_permission(message_id, params) + local session_id = params.sessionId + local tool_call = params.toolCall + local options = params.options + + if not session_id or not tool_call then return end + + if self.config.handlers and self.config.handlers.on_request_permission then + self.config.handlers.on_request_permission( + tool_call, + options, + function(option_id) + self:_send_result(message_id, { + outcome = { + outcome = "selected", + optionId = option_id, + }, + }) + end + ) + end +end + +---Handle fs/read_text_file requests +---@param message_id number +---@param params table +function ACPClient:_handle_read_text_file(message_id, params) + local session_id = params.sessionId + local path = params.path + + if not session_id or not path then return end + + if self.config.handlers and self.config.handlers.on_read_file then + local content = self.config.handlers.on_read_file(path) + self:_send_result(message_id, { content = content }) + end +end + +---Handle fs/write_text_file requests +---@param message_id number +---@param params table +function ACPClient:_handle_write_text_file(message_id, params) + local session_id = params.sessionId + local path = params.path + local content = params.content + + if not session_id or not path or not content then return end + + if self.config.handlers and self.config.handlers.on_write_file then + local error = self.config.handlers.on_write_file(path, content) + self:_send_result(message_id, error == nil and vim.NIL or error) + end +end + +---Start client +function ACPClient:connect() + if self.state ~= "disconnected" then return end + + self.transport:start(function(message) self:_handle_message(message) end) + + self:initialize() +end + +---Stop client +function ACPClient:stop() + self.transport:stop() + + self.pending_responses = {} + self.reconnect_count = 0 +end + +---Initialize protocol connection +function ACPClient:initialize() + if self.state ~= "connected" then + local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected") + return error + end + + self:_set_state("initializing") + + local result = self:_send_request("initialize", { + protocolVersion = self.protocol_version, + clientCapabilities = self.capabilities, + }) + + if not result then + self:_set_state("error") + vim.notify("Failed to initialize", vim.log.levels.ERROR) + return + end + + -- Update protocol version and capabilities + self.protocol_version = result.protocolVersion + self.agent_capabilities = result.agentCapabilities + self.auth_methods = result.authMethods or {} + + -- Check if we need to authenticate + local auth_method = self.config.auth_method + + if auth_method then + vim.print("Authenticating with method " .. auth_method) + self:authenticate(auth_method) + self:_set_state("ready") + else + vim.print("No authentication method found or specified") + self:_set_state("ready") + end +end + +---Authentication (if required) +---@param method_id string +function ACPClient:authenticate(method_id) + return self:_send_request("authenticate", { + methodId = method_id, + }) +end + +---Create new session +---@param cwd string +---@param mcp_servers table[]? +---@return string|nil session_id +---@return ACPError|nil err +function ACPClient:create_session(cwd, mcp_servers) + local result, err = self:_send_request("session/new", { + cwd = cwd, + mcpServers = mcp_servers or {}, + }) + if err then + vim.notify("Failed to create session: " .. err.message, vim.log.levels.ERROR) + return nil, err + end + if not result then + err = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result") + return nil, err + end + return result.sessionId, nil +end + +---Load existing session +---@param session_id string +---@param cwd string +---@param mcp_servers table[]? +---@return table|nil result +function ACPClient:load_session(session_id, cwd, mcp_servers) + if not self.agent_capabilities or not self.agent_capabilities.loadSession then + vim.notify("Agent does not support loading sessions", vim.log.levels.WARN) + return + end + + return self:_send_request("session/load", { + sessionId = session_id, + cwd = cwd, + mcpServers = mcp_servers or {}, + }) +end + +---Send prompt +---@param session_id string +---@param prompt table[] +---@param callback? fun(result: table|nil, err: ACPError|nil) +function ACPClient:send_prompt(session_id, prompt, callback) + local params = { + sessionId = session_id, + prompt = prompt, + } + return self:_send_request("session/prompt", params, callback) +end + +---Cancel session +---@param session_id string +function ACPClient:cancel_session(session_id) + self:_send_notification("session/cancel", { + sessionId = session_id, + }) +end + +---Helper function: Create text content block +---@param text string +---@param annotations table? +---@return table +function ACPClient:create_text_content(text, annotations) + return { + type = "text", + text = text, + annotations = annotations, + } +end + +---Helper function: Create image content block +---@param data string Base64 encoded image data +---@param mime_type string +---@param uri string? +---@param annotations table? +---@return table +function ACPClient:create_image_content(data, mime_type, uri, annotations) + return { + type = "image", + data = data, + mimeType = mime_type, + uri = uri, + annotations = annotations, + } +end + +---Helper function: Create audio content block +---@param data string Base64 encoded audio data +---@param mime_type string +---@param annotations table? +---@return table +function ACPClient:create_audio_content(data, mime_type, annotations) + return { + type = "audio", + data = data, + mimeType = mime_type, + annotations = annotations, + } +end + +---Helper function: Create resource link content block +---@param uri string +---@param name string +---@param description string? +---@param mime_type string? +---@param size number? +---@param title string? +---@param annotations table? +---@return table +function ACPClient:create_resource_link_content(uri, name, description, mime_type, size, title, annotations) + return { + type = "resource_link", + uri = uri, + name = name, + description = description, + mimeType = mime_type, + size = size, + title = title, + annotations = annotations, + } +end + +---Helper function: Create embedded resource content block +---@param resource table +---@param annotations table? +---@return table +function ACPClient:create_resource_content(resource, annotations) + return { + type = "resource", + resource = resource, + annotations = annotations, + } +end + +---Helper function: Create text resource +---@param uri string +---@param text string +---@param mime_type string? +---@return table +function ACPClient:create_text_resource(uri, text, mime_type) + return { + uri = uri, + text = text, + mimeType = mime_type, + } +end + +---Helper function: Create binary resource +---@param uri string +---@param blob string Base64 encoded binary data +---@param mime_type string? +---@return table +function ACPClient:create_blob_resource(uri, blob, mime_type) + return { + uri = uri, + blob = blob, + mimeType = mime_type, + } +end + +---Convenience method: Check if client is ready +---@return boolean +function ACPClient:is_ready() return self.state == "ready" end + +---Convenience method: Check if client is connected +---@return boolean +function ACPClient:is_connected() return self.state ~= "disconnected" and self.state ~= "error" end + +---Convenience method: Get current state +---@return ACPConnectionState +function ACPClient:get_state() return self.state end + +---Convenience method: Wait for client to be ready +---@param callback function +---@param timeout number? Timeout in milliseconds +function ACPClient:wait_ready(callback, timeout) + if self:is_ready() then + callback(nil) + return + end + + local timeout_ms = timeout or 10000 -- 10 seconds default + local start_time = vim.loop.now() + + local function check_ready() + if self:is_ready() then + callback(nil) + elseif self.state == "error" then + callback(self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Client entered error state while waiting")) + elseif vim.loop.now() - start_time > timeout_ms then + callback(self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for client to be ready")) + else + vim.defer_fn(check_ready, 100) -- Check every 100ms + end + end + + check_ready() +end + +---Convenience method: Send simple text prompt +---@param session_id string +---@param text string +function ACPClient:send_text_prompt(session_id, text) + local prompt = { self:create_text_content(text) } + self:send_prompt(session_id, prompt) +end + +return ACPClient diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 033f2f5..c2eb752 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -3,6 +3,7 @@ local fn = vim.fn local uv = vim.uv local curl = require("plenary.curl") +local ACPClient = require("avante.libs.acp_client") local Utils = require("avante.utils") local Prompts = require("avante.utils.prompts") @@ -13,6 +14,7 @@ local Providers = require("avante.providers") local LLMToolHelpers = require("avante.llm_tools.helpers") local LLMTools = require("avante.llm_tools") local History = require("avante.history") +local Selector = require("avante.ui.selector") ---@class avante.LLM local M = {} @@ -486,6 +488,7 @@ end ---@param opts AvanteGeneratePromptsOptions ---@return integer function M.calculate_tokens(opts) + if Config.acp_providers[Config.provider] then return 0 end local prompt_opts = M.generate_prompts(opts) local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt) for _, message in ipairs(prompt_opts.messages) do @@ -795,11 +798,339 @@ local function stop_retry_timer() end end +---@param opts AvanteLLMStreamOptions +function M._stream_acp(opts) + ---@type table + local tool_call_messages = {} + local acp_provider = Config.acp_providers[Config.provider] + local on_messages_add = function(messages) + if opts.on_messages_add then opts.on_messages_add(messages) end + vim.schedule(function() vim.cmd("redraw") end) + end + local function add_tool_call_message(update) + local message = History.Message:new("assistant", { + type = "tool_use", + id = update.toolCallId, + name = update.kind .. "(" .. update.title .. ")", + }) + if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end + tool_call_messages[update.toolCallId] = message + if update.rawInput then + local path = update.rawInput.path or update.rawInput.file_path + if path then + local relative_path = Utils.relative_path(path) + message.displayed_tool_name = update.title .. "(" .. relative_path .. ")" + end + local pattern = update.rawInput.pattern or update.rawInput.search + if pattern then message.displayed_tool_name = update.title .. "(" .. pattern .. ")" end + local command = update.rawInput.command or update.rawInput.command_line + if command then message.displayed_tool_name = update.title .. "(" .. command .. ")" end + local description = update.rawInput.description + if description then + message.tool_use_logs = message.tool_use_logs or {} + table.insert(message.tool_use_logs, description) + end + end + on_messages_add({ message }) + return message + end + local acp_client = opts.acp_client + if not acp_client then + local acp_config = vim.tbl_deep_extend("force", acp_provider, { + handlers = { + on_session_update = function(update) + if update.sessionUpdate == "plan" then + local todos = {} + for idx, entry in ipairs(update.entries) do + local status = "todo" + if entry.status == "in_progress" then status = "doing" end + if entry.status == "completed" then status = "done" end + ---@type avante.TODO + local todo = { + id = tostring(idx), + content = entry.content, + status = status, + priority = entry.priority, + } + table.insert(todos, todo) + end + vim.schedule(function() + if opts.update_todos then opts.update_todos(todos) end + end) + return + end + if update.sessionUpdate == "agent_message_chunk" then + if update.content.type == "text" then + local messages = opts.get_history_messages() + local last_message = messages[#messages] + if last_message and last_message.message.role == "assistant" then + local has_text = false + local content = last_message.message.content + if type(content) == "string" then + last_message.message.content = last_message.message.content .. update.content.text + has_text = true + elseif type(content) == "table" then + for idx, item in ipairs(content) do + if type(item) == "string" then + content[idx] = item .. update.content.text + has_text = true + end + if type(item) == "table" and item.type == "text" then + item.text = item.text .. update.content.text + has_text = true + end + end + end + if has_text then + on_messages_add({ last_message }) + return + end + end + local message = History.Message:new("assistant", update.content.text) + on_messages_add({ message }) + end + end + if update.sessionUpdate == "agent_thought_chunk" then + if update.content.type == "text" then + local message = History.Message:new("assistant", { + type = "thinking", + thinking = update.content.text, + }) + on_messages_add({ message }) + end + end + if update.sessionUpdate == "tool_call" then add_tool_call_message(update) end + if update.sessionUpdate == "tool_call_update" then + local tool_call_message = tool_call_messages[update.toolCallId] + if not tool_call_message then + tool_call_message = History.Message:new("assistant", { + type = "tool_use", + id = update.toolCallId, + name = "", + }) + local update_content = update.content + if type(update_content) == "table" then + for _, item in ipairs(update_content) do + if item.path then + local relative_path = Utils.relative_path(item.path) + tool_call_message.displayed_tool_name = "Edit(" .. relative_path .. ")" + break + end + end + end + if not tool_call_message.displayed_tool_name then + tool_call_message.displayed_tool_name = update.toolCallId + end + end + tool_call_message.tool_use_logs = tool_call_message.tool_use_logs or {} + tool_call_message.tool_use_log_lines = tool_call_message.tool_use_log_lines or {} + local tool_result_message + if update.status == "pending" or update.status == "in_progress" then + tool_call_message.is_calling = true + tool_call_message.state = "generating" + else + tool_call_message.is_calling = false + tool_call_message.state = "generated" + tool_result_message = History.Message:new("assistant", { + type = "tool_result", + tool_use_id = update.toolCallId, + content = update.content, + is_error = update.status == "failed", + is_user_declined = update.status == "cancelled", + }) + end + local messages = { tool_call_message } + if tool_result_message then table.insert(messages, tool_result_message) end + on_messages_add(messages) + end + end, + on_request_permission = function(tool_call, options, callback) + local message = add_tool_call_message(tool_call) + local items = vim + .iter(options) + :map( + function(item) + return { + id = item.optionId, + title = item.name, + } + end + ) + :totable() + local default_item = vim.iter(items):find(function(item) return item.id == options[1].optionId end) + + local function on_select(item_ids) + if not item_ids then return end + local choice = vim.iter(items):find(function(item) return item.id == item_ids[1] end) + if not choice then return end + Utils.debug("on_select", choice.id) + callback(choice.id) + end + + local selector = Selector:new({ + title = message.displayed_tool_name or message.message.content[1].name, + items = items, + default_item_id = default_item and default_item.name or nil, + provider = Config.selector.provider, + provider_opts = Config.selector.provider_opts, + on_select = on_select, + get_preview_content = function(_) + local file_content = "" + local filetype = "text" + local content = tool_call.content + if type(content) == "table" then + for _, item in ipairs(content) do + if item.type == "content" then + if type(item.content) == "table" then + if item.content.type == "text" then + file_content = file_content .. item.content.text .. "\n\n" + end + end + end + if item.type == "diff" then + local unified_diff = Utils.get_unified_diff(item.oldText, item.newText, { algorithm = "myers" }) + local result = "--- a/" .. item.path .. "\n+++ b/" .. item.path .. "\n" .. unified_diff .. "\n\n" + filetype = "diff" + file_content = file_content .. result + end + end + end + return file_content, filetype + end, + }) + + vim.schedule(function() selector:open() end) + end, + on_read_file = function(path, line, limit) + local abs_path = Utils.to_absolute_path(path) + local lines = Utils.read_file_from_buf_or_disk(abs_path) + lines = lines or {} + if line ~= nil and limit ~= nil then lines = vim.list_slice(lines, line, line + limit) end + return table.concat(lines, "\n") + end, + on_write_file = function(path, content) + local abs_path = Utils.to_absolute_path(path) + local file = io.open(abs_path, "w") + if file then + file:write(content) + file:close() + return nil + end + return "Failed to write file: " .. abs_path + end, + }, + }) + acp_client = ACPClient:new(acp_config) + acp_client:connect() + opts.on_save_acp_client(acp_client) + end + local session_id = opts.acp_session_id + if not session_id then + local project_root = Utils.root.get() + local session_id_, err = acp_client:create_session(project_root, {}) + if err then + opts.on_stop({ reason = "error", error = err }) + return + end + if not session_id_ then + opts.on_stop({ reason = "error", error = "Failed to create session" }) + return + end + session_id = session_id_ + opts.on_save_acp_session_id(session_id) + end + local prompt = {} + local history_messages = opts.history_messages or {} + if opts.acp_session_id then + for i = #history_messages, 1, -1 do + local message = history_messages[i] + if message.message.role == "user" then + local content = message.message.content + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" then + table.insert(prompt, { + type = "text", + text = item, + }) + elseif type(item) == "table" and item.type == "text" then + table.insert(prompt, { + type = "text", + text = item.text, + }) + end + end + elseif type(content) == "string" then + table.insert(prompt, { + type = "text", + text = content, + }) + end + break + end + end + else + for _, message in ipairs(history_messages) do + if message.message.role == "user" then + local content = message.message.content + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" then + table.insert(prompt, { + type = "text", + text = item, + }) + elseif type(item) == "table" and item.type == "text" then + table.insert(prompt, { + type = "text", + text = item.text, + }) + end + end + else + table.insert(prompt, { + type = "text", + text = content, + }) + end + end + end + end + if opts.selected_filepaths then + for _, filepath in ipairs(opts.selected_filepaths) do + local lines, error = Utils.read_file_from_buf_or_disk(filepath) + if error ~= nil then + Utils.error("error reading file: " .. error) + else + local abs_path = Utils.to_absolute_path(filepath) + local content = table.concat(lines or {}, "\n") + local filetype = Utils.get_filetype(filepath) + local prompt_item = acp_client:create_resource_content({ + uri = "file://" .. abs_path, + mimeType = "text/x-" .. filetype, + text = content, + }, nil) + table.insert(prompt, prompt_item) + end + end + end + acp_client:send_prompt(session_id, prompt, function(_, err_) + if err_ then + opts.on_stop({ reason = "error", error = err_ }) + return + end + opts.on_stop({ reason = "complete" }) + end) +end + ---@param opts AvanteLLMStreamOptions function M._stream(opts) -- Reset the cancellation flag at the start of a new request if LLMToolHelpers then LLMToolHelpers.is_cancelled = false end + local acp_provider = Config.acp_providers[Config.provider] + if acp_provider then return M._stream_acp(opts) end + local provider = opts.provider or Providers[Config.provider] opts.session_ctx = opts.session_ctx or {} diff --git a/lua/avante/llm_tools/replace_in_file.lua b/lua/avante/llm_tools/replace_in_file.lua index bff8313..d1802ff 100644 --- a/lua/avante/llm_tools/replace_in_file.lua +++ b/lua/avante/llm_tools/replace_in_file.lua @@ -18,9 +18,7 @@ M.description = M.support_streaming = true -function M.enabled() - return require("avante.config").mode == "agentic" and not require("avante.config").behaviour.enable_fastapply -end +function M.enabled() return false end ---@type AvanteLLMToolParam M.param = { @@ -184,17 +182,7 @@ function M.func(input, opts) -- Utils.debug("diff", diff) local err = [[No diff blocks found. -Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE blocks are in the correct order. - -For example: - ``` - ------- SEARCH - [exact content to find] - ======= - [new content to replace with] - +++++++ REPLACE - ``` -]] +Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE blocks are in the correct order.]] return false, err end diff --git a/lua/avante/llm_tools/str_replace.lua b/lua/avante/llm_tools/str_replace.lua index 6714c05..358b9e8 100644 --- a/lua/avante/llm_tools/str_replace.lua +++ b/lua/avante/llm_tools/str_replace.lua @@ -8,7 +8,9 @@ M.name = "str_replace" M.description = "The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits." -function M.enabled() return false end +function M.enabled() + return require("avante.config").mode == "agentic" and not require("avante.config").behaviour.enable_fastapply +end ---@type AvanteLLMToolParam M.param = { diff --git a/lua/avante/providers/init.lua b/lua/avante/providers/init.lua index 4c2ee72..d210eb6 100644 --- a/lua/avante/providers/init.lua +++ b/lua/avante/providers/init.lua @@ -203,6 +203,8 @@ M = setmetatable(M, { function M.setup() vim.g.avante_login = false + if Config.acp_providers[Config.provider] then return end + ---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor local provider = M[Config.provider] diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index b6289ec..d10aaf8 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -73,6 +73,8 @@ Sidebar.__index = Sidebar ---@field input_hint_window integer | nil ---@field old_result_lines avante.ui.Line[] ---@field token_count integer | nil +---@field acp_client ACPClient | nil +---@field acp_session_id string | nil ---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage() function Sidebar:new(id) @@ -199,7 +201,6 @@ function Sidebar:set_code_winhl() if not Utils.is_valid_container(self.containers.result, true) then return end if Utils.should_hidden_border(self.code.winid, self.containers.result.winid) then - Utils.debug("setting winhl") local old_winhl = vim.wo[self.code.winid].winhl if self.code.old_winhl == nil then self.code.old_winhl = old_winhl @@ -1638,9 +1639,15 @@ end ---@param selected_code AvanteSelectedCode? ---@return string local function render_chat_record_prefix(timestamp, provider, model, request, selected_filepaths, selected_code) - provider = provider or "unknown" - model = model or "unknown" - local res = "- Datetime: " .. timestamp .. "\n" .. "- Model: " .. provider .. "/" .. model + local res + local acp_provider = Config.acp_providers[provider] + if acp_provider then + res = "- Datetime: " .. timestamp .. "\n" .. "- ACP: " .. provider + else + provider = provider or "unknown" + model = model or "unknown" + res = "- Datetime: " .. timestamp .. "\n" .. "- Model: " .. provider .. "/" .. model + end if selected_filepaths ~= nil and #selected_filepaths > 0 then res = res .. "\n- Selected files:" for _, path in ipairs(selected_filepaths) do @@ -1740,10 +1747,26 @@ local _message_to_lines_lru_cache = LRUCache:new(100) ---@return avante.ui.Line[] local function get_message_lines(message, messages, ctx) if message.state == "generating" or message.is_calling then return _get_message_lines(message, messages, ctx) end - local cached_lines = _message_to_lines_lru_cache:get(message.uuid) + local text_len = 0 + local content = message.message.content + if type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" then + text_len = text_len + #item + else + for _, subitem in ipairs(item) do + if type(subitem) == "string" then text_len = text_len + #subitem end + end + end + end + elseif type(content) == "string" then + text_len = #content + end + local cache_key = message.uuid .. ":" .. tostring(text_len) + local cached_lines = _message_to_lines_lru_cache:get(cache_key) if cached_lines then return cached_lines end local lines = _get_message_lines(message, messages, ctx) - _message_to_lines_lru_cache:set(message.uuid, lines) + _message_to_lines_lru_cache:set(cache_key, lines) return lines end @@ -1934,6 +1957,7 @@ function Sidebar:render_state() if self.current_state == "thinking" then hl = "AvanteStateSpinnerThinking" end if self.current_state == "compacting" then hl = "AvanteStateSpinnerCompacting" end local spinner_char = spinner_chars[self.state_spinner_idx] + if not spinner_char then spinner_char = spinner_chars[1] end self.state_spinner_idx = (self.state_spinner_idx % #spinner_chars) + 1 if self.current_state ~= "generating" @@ -2004,6 +2028,7 @@ function Sidebar:new_chat(args, cb) Path.history.save(self.code.bufnr, history) self:reload_chat_history() self.current_state = nil + self.acp_session_id = nil self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end }) if cb then cb(args) end vim.schedule(function() self:create_todos_container() end) @@ -2035,13 +2060,16 @@ function Sidebar:update_todos(todos) end ---@param messages avante.HistoryMessage | avante.HistoryMessage[] -function Sidebar:add_history_messages(messages) +---@param opts? {eager_update?: boolean} +function Sidebar:add_history_messages(messages, opts) local history_messages = History.get_history_messages(self.chat_history) messages = vim.islist(messages) and messages or { messages } for _, message in ipairs(messages) do if message.is_user_submission then message.provider = Config.provider - message.model = Config.get_provider_config(Config.provider).model + if not Config.acp_providers[Config.provider] then + message.model = Config.get_provider_config(Config.provider).model + end end local idx = nil for idx_, message_ in ipairs(history_messages) do @@ -2082,6 +2110,10 @@ function Sidebar:add_history_messages(messages) self.current_state = "generating" end end + if opts and opts.eager_update then + pcall(function() self:update_content("") end) + return + end xpcall(function() self:throttled_update_content("") end, function(err) Utils.debug("Failed to update content:", err) return nil @@ -2275,13 +2307,15 @@ function Sidebar:get_history_messages_for_api(opts) :totable() end - local tool_limit - if Providers[Config.provider].use_ReAct_prompt then - tool_limit = nil - else - tool_limit = 25 + if not Config.acp_providers[Config.provider] then + local tool_limit + if Providers[Config.provider].use_ReAct_prompt then + tool_limit = nil + else + tool_limit = 25 + end + messages = History.update_tool_invocation_history(messages, tool_limit, Config.behaviour.auto_check_diagnostics) end - messages = History.update_tool_invocation_history(messages, tool_limit, Config.behaviour.auto_check_diagnostics) end return messages @@ -2590,12 +2624,17 @@ function Sidebar:create_input_container() on_tool_log = on_tool_log, on_messages_add = on_messages_add, on_state_change = on_state_change, + acp_client = self.acp_client, + on_save_acp_client = function(client) self.acp_client = client end, + acp_session_id = self.acp_session_id, + on_save_acp_session_id = function(session_id) self.acp_session_id = session_id end, set_tool_use_store = set_tool_use_store, get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end, get_todos = function() local history = Path.history.load(self.code.bufnr) return history and history.todos or {} end, + update_todos = function(todos) self:update_todos(todos) end, session_ctx = {}, ---@param usage avante.LLMTokenUsage update_tokens_usage = function(usage) diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 236adfe..9b1ace9 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -98,6 +98,7 @@ vim.g.avante_login = vim.g.avante_login ---@field timestamp string ---@field state avante.HistoryMessageState ---@field uuid string | nil +---@field displayed_tool_name string | nil ---@field displayed_content string | nil ---@field visible boolean | nil ---@field is_context boolean | nil @@ -107,6 +108,7 @@ vim.g.avante_login = vim.g.avante_login ---@field selected_code AvanteSelectedCode | nil ---@field selected_filepaths string[] | nil ---@field tool_use_logs string[] | nil +---@field tool_use_log_lines avante.ui.Line[] | nil ---@field tool_use_store table | nil ---@field just_for_display boolean | nil ---@field is_dummy boolean | nil @@ -327,7 +329,7 @@ vim.g.avante_login = vim.g.avante_login --- ---@class AvanteProviderFunctor ---@field _model_list_cache table ----@field extra_headers function(table) -> table | table | nil +---@field extra_headers fun(table): table | table | nil ---@field support_prompt_caching boolean | nil ---@field role_map table<"user" | "assistant", string> ---@field parse_messages AvanteMessagesParser @@ -359,6 +361,12 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_response AvanteResponseParser ---@field build_bedrock_payload AvanteBedrockPayloadBuilder --- +---@class AvanteACPProvider +---@field command string +---@field args string[] +---@field env table +---@field auth_method string +--- ---@alias AvanteLlmMode avante.Mode | "editing" | "suggesting" --- ---@class AvanteSelectedCode @@ -382,6 +390,7 @@ vim.g.avante_login = vim.g.avante_login ---@field diagnostics string | nil ---@field history_messages avante.HistoryMessage[] | nil ---@field get_todos? fun(): avante.TODO[] +---@field update_todos? fun(todos: avante.TODO[]): nil ---@field memory string | nil ---@field get_tokens_usage? fun(): avante.LLMTokenUsage | nil --- @@ -405,6 +414,10 @@ vim.g.avante_login = vim.g.avante_login ---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking" | "compacting" | "compacted" | "initializing" | "initialized" --- ---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions +---@field acp_client? ACPClient +---@field on_save_acp_client? fun(client: ACPClient): nil +---@field acp_session_id? string +---@field on_save_acp_session_id? fun(session_id: string): nil ---@field on_start AvanteLLMStartCallback ---@field on_chunk? AvanteLLMChunkCallback ---@field on_stop AvanteLLMStopCallback diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 1e0e913..f79e686 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -1700,4 +1700,12 @@ function M.fix_diff(diff) return table.concat(the_final_diff_lines, "\n") end +function M.get_unified_diff(text1, text2, opts) + opts = opts or {} + opts.result_type = "unified" + opts.ctxlen = opts.ctxlen or 3 + + return vim.diff(text1, text2, opts) +end + return M diff --git a/lua/avante/utils/tokens.lua b/lua/avante/utils/tokens.lua index 97d82cc..6a88893 100644 --- a/lua/avante/utils/tokens.lua +++ b/lua/avante/utils/tokens.lua @@ -26,7 +26,7 @@ function Tokens.calculate_tokens(content) elseif type(item) == "table" and item.type == "image" then text = text .. item.source.data elseif type(item) == "table" and item.type == "tool_result" then - text = text .. item.content + if type(item.content) == "string" then text = text .. item.content end end end end