Files
avante.nvim/lua/avante/libs/acp_client.lua
2025-08-31 07:41:19 +08:00

892 lines
24 KiB
Lua

---@class ClientCapabilities
---@field fs FileSystemCapability
---@class FileSystemCapability
---@field readTextFile boolean
---@field writeTextFile boolean
---@class AgentCapabilities
---@field loadSession boolean
---@field promptCapabilities PromptCapabilities
---@class PromptCapabilities
---@field image boolean
---@field audio boolean
---@field embeddedContext boolean
---@class AuthMethod
---@field id string
---@field name string
---@field description string|nil
---@class McpServer
---@field name string
---@field command string
---@field args string[]
---@field env EnvVariable[]
---@class EnvVariable
---@field name string
---@field value string
---@alias StopReason "end_turn" | "max_tokens" | "max_turn_requests" | "refusal" | "cancelled"
---@alias ToolKind "read" | "edit" | "delete" | "move" | "search" | "execute" | "think" | "fetch" | "other"
---@alias ToolCallStatus "pending" | "in_progress" | "completed" | "failed"
---@alias PlanEntryStatus "pending" | "in_progress" | "completed"
---@alias PlanEntryPriority "high" | "medium" | "low"
---@class ContentBlock
---@field type "text" | "image" | "audio" | "resource_link" | "resource"
---@field annotations Annotations|nil
---@class TextContent : ContentBlock
---@field type "text"
---@field text string
---@class ImageContent : ContentBlock
---@field type "image"
---@field data string
---@field mimeType string
---@field uri string|nil
---@class AudioContent : ContentBlock
---@field type "audio"
---@field data string
---@field mimeType string
---@class ResourceLinkContent : ContentBlock
---@field type "resource_link"
---@field uri string
---@field name string
---@field description string|nil
---@field mimeType string|nil
---@field size number|nil
---@field title string|nil
---@class ResourceContent : ContentBlock
---@field type "resource"
---@field resource EmbeddedResource
---@class EmbeddedResource
---@field uri string
---@field text string|nil
---@field blob string|nil
---@field mimeType string|nil
---@class Annotations
---@field audience any[]|nil
---@field lastModified string|nil
---@field priority number|nil
---@class ToolCall
---@field toolCallId string
---@field title string
---@field kind ToolKind
---@field status ToolCallStatus
---@field content ToolCallContent[]
---@field locations ToolCallLocation[]
---@field rawInput table
---@field rawOutput table
---@class ToolCallContent
---@field type "content" | "diff"
---@class ToolCallContentBlock : ToolCallContent
---@field type "content"
---@field content ContentBlock
---@class ToolCallDiff : ToolCallContent
---@field type "diff"
---@field path string
---@field oldText string|nil
---@field newText string
---@class ToolCallLocation
---@field path string
---@field line number|nil
---@class PlanEntry
---@field content string
---@field priority PlanEntryPriority
---@field status PlanEntryStatus
---@class Plan
---@field entries PlanEntry[]
---@class SessionUpdate
---@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan"
---@class UserMessageChunk : SessionUpdate
---@field sessionUpdate "user_message_chunk"
---@field content ContentBlock
---@class AgentMessageChunk : SessionUpdate
---@field sessionUpdate "agent_message_chunk"
---@field content ContentBlock
---@class AgentThoughtChunk : SessionUpdate
---@field sessionUpdate "agent_thought_chunk"
---@field content ContentBlock
---@class ToolCallUpdate : SessionUpdate
---@field sessionUpdate "tool_call" | "tool_call_update"
---@field toolCallId string
---@field title string|nil
---@field kind ToolKind|nil
---@field status ToolCallStatus|nil
---@field content ToolCallContent[]|nil
---@field locations ToolCallLocation[]|nil
---@field rawInput table|nil
---@field rawOutput table|nil
---@class PlanUpdate : SessionUpdate
---@field sessionUpdate "plan"
---@field entries PlanEntry[]
---@class PermissionOption
---@field optionId string
---@field name string
---@field kind "allow_once" | "allow_always" | "reject_once" | "reject_always"
---@class RequestPermissionOutcome
---@field outcome "cancelled" | "selected"
---@field optionId string|nil
---@class ACPTransport
---@field send function
---@field start function
---@field stop function
---@alias ACPConnectionState "disconnected" | "connecting" | "connected" | "initializing" | "ready" | "error"
---@class ACPError
---@field code number
---@field message string
---@field data any|nil
---@class ACPClient
---@field protocol_version number
---@field capabilities ClientCapabilities
---@field agent_capabilities AgentCapabilities|nil
---@field config ACPConfig
---@field callbacks table<number, fun(result: table|nil, err: ACPError|nil)>
local ACPClient = {}
-- Numeric error codes this client puts into the `code` field of ACPError tables
-- (see `_create_error`) and of JSON-RPC error replies (see `_send_error`).
-- They live on the class table, so instances reach them as `self.ERROR_CODES`.
ACPClient.ERROR_CODES = {
TRANSPORT_ERROR = -32000, -- default code used by `_send_error` when the caller passes none
PROTOCOL_ERROR = -32001, -- used e.g. when `initialize` runs while not connected, or a result is missing
TIMEOUT_ERROR = -32002, -- used when waiting for a response or for the ready state times out
AUTH_REQUIRED = -32003, -- not referenced elsewhere in the visible source
SESSION_NOT_FOUND = -32004, -- not referenced elsewhere in the visible source
PERMISSION_DENIED = -32005, -- not referenced elsewhere in the visible source
INVALID_REQUEST = -32006, -- not referenced elsewhere in the visible source
}
---@class ACPHandlers
---@field on_session_update? function
---@field on_request_permission? function
---@field on_read_file? function
---@field on_write_file? function
---@field on_error? function
---@class ACPConfig
---@field transport_type "stdio" | "websocket" | "tcp"
---@field command? string Command to spawn agent (for stdio)
---@field args? string[] Arguments for agent command
---@field env? table Environment variables
---@field host? string Host for tcp/websocket
---@field port? number Port for tcp/websocket
---@field timeout? number Request timeout in milliseconds
---@field reconnect? boolean Enable auto-reconnect
---@field max_reconnect_attempts? number Maximum reconnection attempts
---@field heartbeat_interval? number Heartbeat interval in milliseconds
---@field auth_method? string Authentication method
---@field handlers? ACPHandlers
---@field on_state_change? fun(new_state: ACPConnectionState, old_state: ACPConnectionState)
---Construct an ACP client and prepare its transport.
---The instance starts in the "disconnected" state; nothing is spawned until `connect()`.
---@param config ACPConfig Client configuration (an empty table is used when nil)
---@return ACPClient
function ACPClient:new(config)
  local instance = {
    id_counter = 0,
    protocol_version = 1,
    capabilities = {
      fs = { readTextFile = true, writeTextFile = true },
    },
    pending_responses = {},
    callbacks = {},
    transport = nil,
    config = config or {},
    state = "disconnected",
    reconnect_count = 0,
    heartbeat_timer = nil,
  }
  -- Methods are resolved through the table `new` was called on.
  setmetatable(instance, { __index = self })
  instance:_setup_transport()
  return instance
end
---Build the transport named by `config.transport_type` (default "stdio") and
---store it in `self.transport`. Raises for any unrecognised transport type.
function ACPClient:_setup_transport()
  local kind = self.config.transport_type or "stdio"
  local factories = {
    stdio = self._create_stdio_transport,
    websocket = self._create_websocket_transport,
    tcp = self._create_tcp_transport,
  }
  local factory = factories[kind]
  if not factory then error("Unsupported transport type: " .. kind) end
  self.transport = factory(self)
end
---Record a new connection state and report the transition to
---`config.on_state_change` when that callback is configured.
---@param state ACPConnectionState
function ACPClient:_set_state(state)
  local previous = self.state
  self.state = state
  local listener = self.config.on_state_change
  if listener then listener(state, previous) end
end
---Build an ACPError table from its parts.
---@param code number One of `ACPClient.ERROR_CODES` (or any numeric code)
---@param message string Human-readable description
---@param data any? Optional extra payload
---@return ACPError
function ACPClient:_create_error(code, message, data)
  local err = {}
  err.code = code
  err.message = message
  err.data = data
  return err
end
---Create the stdio transport: it spawns `config.command` and exchanges
---newline-delimited JSON messages over the child's stdin/stdout.
---
---The returned table exposes `send`, `start` and `stop`; each expects the transport
---itself as first argument (i.e. call them as `transport:send(data)`).
---@return ACPTransport
function ACPClient:_create_stdio_transport()
  local uv = vim.loop
  local transport = {
    stdin = nil,
    stdout = nil,
    stderr = nil,
    process = nil,
  }

  -- Close a libuv handle at most once; luv raises when closing a handle twice.
  local function close_handle(handle)
    if handle and not handle:is_closing() then handle:close() end
  end

  ---Write one message to the agent's stdin (a trailing newline is appended).
  ---@return boolean sent false when there is no writable stdin pipe
  function transport.send(transport_self, data)
    local stdin = transport_self.stdin
    if stdin and not stdin:is_closing() then
      stdin:write(data .. "\n")
      return true
    end
    return false
  end

  ---Spawn the agent and start reading its output.
  ---`on_message` is called with every decoded JSON message, from inside the
  ---libuv read callback. Raises when pipes cannot be created or the spawn fails.
  function transport.start(transport_self, on_message)
    self:_set_state("connecting")
    local stdin = uv.new_pipe(false)
    local stdout = uv.new_pipe(false)
    local stderr = uv.new_pipe(false)
    if not stdin or not stdout or not stderr then
      -- Release whichever pipes were created before failing.
      close_handle(stdin)
      close_handle(stdout)
      close_handle(stderr)
      self:_set_state("error")
      error("Failed to create pipes for ACP agent")
    end

    local args = vim.deepcopy(self.config.args or {})

    -- The child's environment consists of Neovim's PATH plus the entries of
    -- config.env; a PATH given in config.env replaces the inherited one.
    local env_map = {}
    local path = vim.fn.getenv("PATH")
    -- vim.fn.getenv() yields vim.NIL (which is truthy) for an unset variable.
    if path ~= nil and path ~= vim.NIL then env_map.PATH = path end
    for name, value in pairs(self.config.env or {}) do
      env_map[name] = value
    end
    local final_env = {}
    for name, value in pairs(env_map) do
      final_env[#final_env + 1] = name .. "=" .. tostring(value)
    end

    local handle
    ---@diagnostic disable-next-line: missing-fields
    handle = uv.spawn(self.config.command, {
      args = args,
      env = final_env,
      stdio = { stdin, stdout, stderr },
    }, function(code, signal)
      vim.print("ACP agent exited with code " .. code .. " and signal " .. signal)
      self:_set_state("disconnected")
      close_handle(handle)
      if transport_self.process == handle then transport_self.process = nil end
      -- Nothing reads the write end any more. stdout/stderr are closed when
      -- they report EOF so that output still buffered in the pipes is not lost.
      close_handle(stdin)
      if transport_self.stdin == stdin then transport_self.stdin = nil end
      -- Handle auto-reconnect
      local max_attempts = self.config.max_reconnect_attempts or 3
      if self.config.reconnect and self.reconnect_count < max_attempts then
        self.reconnect_count = self.reconnect_count + 1
        vim.defer_fn(function()
          if self.state == "disconnected" then self:connect() end
        end, 2000) -- Wait 2 seconds before reconnecting
      end
    end)
    if not handle then
      close_handle(stdin)
      close_handle(stdout)
      close_handle(stderr)
      self:_set_state("error")
      error("Failed to spawn ACP agent process")
    end

    transport_self.process = handle
    transport_self.stdin = stdin
    transport_self.stdout = stdout
    transport_self.stderr = stderr
    self:_set_state("connected")

    -- Read stdout: accumulate chunks and hand over every complete line.
    local buffer = ""
    stdout:read_start(function(err, data)
      if err then
        -- luv callbacks run in a fast event context, where vim.notify must be deferred.
        vim.schedule(function() vim.notify("ACP stdout error: " .. err, vim.log.levels.ERROR) end)
        self:_set_state("error")
        return
      end
      if not data then
        -- EOF: the agent closed its stdout.
        close_handle(stdout)
        if transport_self.stdout == stdout then transport_self.stdout = nil end
        return
      end
      buffer = buffer .. data
      -- Split on newlines and process complete JSON-RPC messages
      local lines = vim.split(buffer, "\n", { plain = true })
      buffer = lines[#lines] -- Keep incomplete line in buffer
      for i = 1, #lines - 1 do
        local line = vim.trim(lines[i])
        if line ~= "" then
          local ok, message = pcall(vim.json.decode, line)
          if ok then
            on_message(message)
          else
            vim.schedule(
              function() vim.notify("Failed to parse JSON-RPC message: " .. line, vim.log.levels.WARN) end
            )
          end
        end
      end
    end)

    -- Read stderr for debugging
    stderr:read_start(function(_, data)
      if data then
        vim.schedule(function() vim.notify("ACP stderr: " .. data, vim.log.levels.DEBUG) end)
      else
        -- EOF or read error: this pipe will not deliver anything else.
        close_handle(stderr)
        if transport_self.stderr == stderr then transport_self.stderr = nil end
      end
    end)
  end

  ---Terminate the agent, release all handles and mark the client disconnected.
  function transport.stop(transport_self)
    local process = transport_self.process
    transport_self.process = nil
    if process and not process:is_closing() then
      -- Closing the handle alone does not end the child process, so ask it to
      -- terminate first. The handle is closed right away, as before, so the
      -- exit callback (and with it auto-reconnect) does not run for an explicit stop.
      pcall(process.kill, process, "sigterm")
      process:close()
    end
    close_handle(transport_self.stdin)
    transport_self.stdin = nil
    close_handle(transport_self.stdout)
    transport_self.stdout = nil
    close_handle(transport_self.stderr)
    transport_self.stderr = nil
    self:_set_state("disconnected")
  end

  return transport
end
---WebSocket transport placeholder: not implemented, always raises.
function ACPClient:_create_websocket_transport()
  error("WebSocket transport not implemented yet")
end
---TCP transport placeholder: not implemented, always raises.
function ACPClient:_create_tcp_transport()
  error("TCP transport not implemented yet")
end
---Advance the request counter and return the new value as the next request id.
---@return number
function ACPClient:_next_id()
  local next_id = self.id_counter + 1
  self.id_counter = next_id
  return next_id
end
---Send a JSON-RPC request to the agent.
---
---With `callback`, the call returns immediately and the callback later receives
---`(result, err)`. Without it, the call blocks in `_wait_response` until the
---response arrives or the timeout elapses.
---
---When the transport cannot send, the failure is reported as a TRANSPORT_ERROR:
---it is returned as `nil, err`, and the callback (if any) is invoked with it and
---unregistered, so it is never left waiting for a response that cannot come.
---@param method string
---@param params table?
---@param callback? fun(result: table|nil, err: ACPError|nil)
---@return table|nil result
---@return ACPError|nil err
function ACPClient:_send_request(method, params, callback)
  local id = self:_next_id()
  local message = {
    jsonrpc = "2.0",
    id = id,
    method = method,
    params = params or {},
  }
  if callback then self.callbacks[id] = callback end
  local data = vim.json.encode(message) .. "\n"
  if not self.transport:send(data) then
    self.callbacks[id] = nil
    local err = self:_create_error(
      self.ERROR_CODES.TRANSPORT_ERROR,
      "Failed to send request '" .. tostring(method) .. "': transport is not writable"
    )
    if callback then callback(nil, err) end
    return nil, err
  end
  if not callback then return self:_wait_response(id) end
end
---Block until the response for request `id` has been stored in
---`self.pending_responses`, or until `config.timeout` (default 100000 ms) elapses.
---The stored entry is removed once it has been consumed.
---@param id number Request id returned by `_next_id`
---@return table|nil result
---@return ACPError|nil err
function ACPClient:_wait_response(id)
  local timeout = self.config.timeout or 100000
  -- vim.wait keeps processing events while polling the predicate every 10 ms;
  -- its second return value is -2 when the wait was interrupted.
  local _, reason = vim.wait(timeout, function() return self.pending_responses[id] ~= nil end, 10)
  local entry = self.pending_responses[id]
  if entry then
    self.pending_responses[id] = nil
    -- Index explicitly: the entry is { result, err } and result may be nil, in
    -- which case unpack() would depend on the length of a table with a hole.
    return entry[1], entry[2]
  end
  local reason_text = "Timeout waiting for response"
  if reason == -2 then reason_text = "Interrupted while waiting for response" end
  return nil, self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, reason_text)
end
---Send a JSON-RPC notification (a message without an id). Whether the
---transport accepted the data is not reported to the caller.
---@param method string
---@param params table?
function ACPClient:_send_notification(method, params)
  local payload = vim.json.encode({
    jsonrpc = "2.0",
    method = method,
    params = params or {},
  })
  self.transport:send(payload .. "\n")
end
---Answer an agent request with a successful JSON-RPC response.
---@param id number Id of the request being answered
---@param result table Value placed in the response's `result` member
---@return nil
function ACPClient:_send_result(id, result)
  local response = { jsonrpc = "2.0", id = id, result = result }
  self.transport:send(vim.json.encode(response) .. "\n")
end
---Answer an agent request with a JSON-RPC error response.
---@param id number Id of the request being answered
---@param message string Error description
---@param code? number Error code; defaults to `ERROR_CODES.TRANSPORT_ERROR`
---@return nil
function ACPClient:_send_error(id, message, code)
  if not code then code = self.ERROR_CODES.TRANSPORT_ERROR end
  local response = {
    jsonrpc = "2.0",
    id = id,
    error = { code = code, message = message },
  }
  self.transport:send(vim.json.encode(response) .. "\n")
end
---Route one decoded JSON-RPC message received from the agent.
---
---* A message with `method` (and neither `result` nor `error`) is an agent
---  request or notification and goes to `_handle_notification`.
---* A message with `id` plus `result`/`error` is the response to one of our
---  requests: it is handed to the registered callback, or stored in
---  `pending_responses` for `_wait_response` when no callback exists.
---* Anything else is reported as unknown.
---@param message table
function ACPClient:_handle_message(message)
  -- The transport calls this from a libuv callback, so notifications are deferred.
  local function warn_unknown()
    vim.schedule(
      function() vim.notify("Unknown message type: " .. vim.inspect(message), vim.log.levels.WARN) end
    )
  end

  -- Valid JSON is not necessarily an object (it may be a number, string, null, ...).
  if type(message) ~= "table" then
    warn_unknown()
    return
  end

  if message.method and not message.result and not message.error then
    self:_handle_notification(message.id, message.method, message.params)
    return
  end

  if message.id and (message.result or message.error) then
    local err = message.error
    -- `"error": null` decodes to vim.NIL, which is truthy; it means "no error".
    if err == vim.NIL then err = nil end
    local callback = self.callbacks[message.id]
    if callback then
      -- Unregister first so the entry is dropped even if the callback raises.
      self.callbacks[message.id] = nil
      callback(message.result, err)
    else
      self.pending_responses[message.id] = { message.result, err }
    end
    return
  end

  warn_unknown()
end
---Dispatch an incoming message that carries a `method`.
---Despite the name this covers both notifications (no id) and agent requests
---(with an id, which the specific handler answers).
---@param message_id number|nil Id of the request, nil for a notification
---@param method string
---@param params table|nil
function ACPClient:_handle_notification(message_id, method, params)
  if method == "session/update" then
    self:_handle_session_update(params)
  elseif method == "session/request_permission" then
    self:_handle_request_permission(message_id, params)
  elseif method == "fs/read_text_file" then
    self:_handle_read_text_file(message_id, params)
  elseif method == "fs/write_text_file" then
    self:_handle_write_text_file(message_id, params)
  else
    -- This may run inside a libuv callback, where vim.notify must be deferred.
    vim.schedule(
      function() vim.notify("Unknown notification method: " .. tostring(method), vim.log.levels.WARN) end
    )
    -- A message with an id is a request and the sender expects an answer;
    -- -32601 is the JSON-RPC 2.0 "Method not found" code.
    if message_id ~= nil then
      self:_send_error(message_id, "Method not found: " .. tostring(method), -32601)
    end
  end
end
---Handle a `session/update` notification: validate it and pass the `update`
---payload to `config.handlers.on_session_update`, if configured.
---Malformed notifications are reported with a warning and dropped.
---@param params table|nil Expected to contain `sessionId` and `update`
function ACPClient:_handle_session_update(params)
  -- This may run inside a libuv callback, where vim.notify must be deferred.
  local function warn(text)
    vim.schedule(function() vim.notify(text, vim.log.levels.WARN) end)
  end

  -- `params` is absent or not an object when the agent sends a malformed message.
  if type(params) ~= "table" or not params.sessionId then
    warn("Received session/update without sessionId")
    return
  end
  local update = params.update
  if not update then
    warn("Received session/update without update data")
    return
  end

  local handlers = self.config.handlers
  if handlers and handlers.on_session_update then handlers.on_session_update(update) end
end
---Handle a `session/request_permission` request from the agent.
---
---`config.handlers.on_request_permission(tool_call, options, respond)` decides;
---calling `respond(option_id)` sends the answer. A nil `option_id` is reported
---as a "cancelled" outcome, anything else as "selected". Only the first call to
---`respond` sends a reply.
---
---The request is always answered: invalid parameters or a missing handler
---produce a JSON-RPC error so the agent is not left waiting.
---@param message_id number
---@param params table
function ACPClient:_handle_request_permission(message_id, params)
  if type(params) ~= "table" or not params.sessionId or not params.toolCall then
    if message_id ~= nil then
      self:_send_error(
        message_id,
        "Invalid session/request_permission parameters",
        self.ERROR_CODES.INVALID_REQUEST
      )
    end
    return
  end

  local handlers = self.config.handlers
  local handler = handlers and handlers.on_request_permission
  if not handler then
    if message_id ~= nil then
      self:_send_error(message_id, "No permission handler configured", self.ERROR_CODES.PERMISSION_DENIED)
    end
    return
  end

  local answered = false
  handler(params.toolCall, params.options, function(option_id)
    -- A request may only be answered once.
    if answered then return end
    answered = true
    local outcome
    if option_id == nil then
      outcome = { outcome = "cancelled" }
    else
      outcome = { outcome = "selected", optionId = option_id }
    end
    self:_send_result(message_id, { outcome = outcome })
  end)
end
---Handle an `fs/read_text_file` request from the agent.
---
---The file content is obtained from `config.handlers.on_read_file(path)` and
---returned as `{ content = ... }`. The request is always answered: invalid
---parameters, a missing handler, a handler that raises, or a handler that
---returns nil all produce a JSON-RPC error instead of silence.
---@param message_id number
---@param params table
function ACPClient:_handle_read_text_file(message_id, params)
  if type(params) ~= "table" or not params.sessionId or not params.path then
    self:_send_error(message_id, "Invalid fs/read_text_file parameters", self.ERROR_CODES.INVALID_REQUEST)
    return
  end

  local handlers = self.config.handlers
  local handler = handlers and handlers.on_read_file
  if not handler then
    -- -32601 is the JSON-RPC 2.0 "Method not found" code.
    self:_send_error(message_id, "fs/read_text_file is not supported by this client", -32601)
    return
  end

  local ok, content = pcall(handler, params.path)
  if not ok then
    self:_send_error(message_id, "Failed to read file: " .. tostring(content))
    return
  end
  if content == nil then
    -- `{ content = nil }` is an empty table and would not carry a content field.
    self:_send_error(message_id, "Failed to read file: " .. tostring(params.path))
    return
  end
  self:_send_result(message_id, { content = content })
end
---Handle an `fs/write_text_file` request from the agent.
---
---`config.handlers.on_write_file(path, content)` performs the write and returns
---nil on success or an error value on failure. Success is answered with a null
---result; a returned error (or a handler that raises) is answered with a
---JSON-RPC error, so the agent does not mistake a failed write for a success.
---Invalid parameters and a missing handler are answered with an error as well.
---@param message_id number
---@param params table
function ACPClient:_handle_write_text_file(message_id, params)
  if type(params) ~= "table" or not params.sessionId or not params.path or not params.content then
    self:_send_error(message_id, "Invalid fs/write_text_file parameters", self.ERROR_CODES.INVALID_REQUEST)
    return
  end

  local handlers = self.config.handlers
  local handler = handlers and handlers.on_write_file
  if not handler then
    -- -32601 is the JSON-RPC 2.0 "Method not found" code.
    self:_send_error(message_id, "fs/write_text_file is not supported by this client", -32601)
    return
  end

  local ok, write_err = pcall(handler, params.path, params.content)
  if ok and not write_err then
    self:_send_result(message_id, vim.NIL)
    return
  end

  -- Accept both plain values and ACPError-like tables as error descriptions.
  local detail = write_err
  if type(write_err) == "table" and write_err.message ~= nil then detail = write_err.message end
  self:_send_error(message_id, "Failed to write file: " .. tostring(detail))
end
---Start the transport and run the protocol handshake.
---Does nothing unless the client is currently "disconnected".
function ACPClient:connect()
  if self.state == "disconnected" then
    local function on_message(message) self:_handle_message(message) end
    self.transport:start(on_message)
    self:initialize()
  end
end
---Stop the transport and reset per-connection bookkeeping.
---Stored responses and registered response callbacks are discarded (the
---callbacks are not invoked), and the reconnect counter starts over.
function ACPClient:stop()
  self.transport:stop()
  self.pending_responses = {}
  -- Responses for these requests can no longer arrive; keeping the callbacks
  -- would only retain them for the lifetime of the client.
  self.callbacks = {}
  self.reconnect_count = 0
end
---Run the `initialize` handshake and, when `config.auth_method` is set, authenticate.
---
---Requires the "connected" state. On success the agent's protocol version,
---capabilities and auth methods are stored and the state becomes "ready".
---On any failure the state becomes "error" (except when the client was not
---connected in the first place, which leaves the state untouched).
---Blocks until the agent answers or the request times out.
---@return ACPError|nil err nil on success
function ACPClient:initialize()
  if self.state ~= "connected" then
    return self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected")
  end
  self:_set_state("initializing")

  local result, err = self:_send_request("initialize", {
    protocolVersion = self.protocol_version,
    clientCapabilities = self.capabilities,
  })
  -- A decoded `"error": null` is vim.NIL and does not indicate a failure.
  if err == vim.NIL then err = nil end
  if err or type(result) ~= "table" then
    self:_set_state("error")
    vim.notify("Failed to initialize", vim.log.levels.ERROR)
    return err or self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to initialize: missing result")
  end

  -- Update protocol version and capabilities
  self.protocol_version = result.protocolVersion or self.protocol_version
  self.agent_capabilities = result.agentCapabilities
  self.auth_methods = result.authMethods or {}

  -- Check if we need to authenticate
  local auth_method = self.config.auth_method
  if auth_method then
    vim.print("Authenticating with method " .. auth_method)
    local auth_result, auth_err = self:authenticate(auth_method)
    if auth_err == vim.NIL then auth_err = nil end
    -- A successful response always carries a result (null decodes to vim.NIL),
    -- so a nil result means the request failed even when no error came back.
    if auth_err or auth_result == nil then
      auth_err = auth_err
        or self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to authenticate: missing result")
      local detail = auth_err
      if type(auth_err) == "table" and auth_err.message ~= nil then detail = auth_err.message end
      self:_set_state("error")
      vim.notify("Failed to authenticate: " .. tostring(detail), vim.log.levels.ERROR)
      return auth_err
    end
  else
    vim.print("No authentication method found or specified")
  end
  self:_set_state("ready")
end
---Send an `authenticate` request for the given method and wait for the answer.
---Returns whatever `_send_request` returns.
---@param method_id string
function ACPClient:authenticate(method_id)
  local params = { methodId = method_id }
  return self:_send_request("authenticate", params)
end
---Ask the agent for a new session (`session/new`) and wait for the answer.
---@param cwd string Working directory sent to the agent
---@param mcp_servers table[]? MCP servers to attach; an empty list when omitted
---@return string|nil session_id
---@return ACPError|nil err
function ACPClient:create_session(cwd, mcp_servers)
  local params = {
    cwd = cwd,
    mcpServers = mcp_servers or {},
  }
  local response, failure = self:_send_request("session/new", params)
  if failure then
    vim.notify("Failed to create session: " .. failure.message, vim.log.levels.ERROR)
    return nil, failure
  end
  if response then return response.sessionId, nil end
  -- Neither an error nor a result came back.
  failure = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result")
  return nil, failure
end
---Ask the agent to load an existing session (`session/load`).
---Warns and returns nothing when the agent did not advertise `loadSession`.
---@param session_id string
---@param cwd string
---@param mcp_servers table[]?
---@return table|nil result
function ACPClient:load_session(session_id, cwd, mcp_servers)
  local caps = self.agent_capabilities
  local supported = caps and caps.loadSession
  if not supported then
    vim.notify("Agent does not support loading sessions", vim.log.levels.WARN)
    return
  end
  local params = {
    sessionId = session_id,
    cwd = cwd,
    mcpServers = mcp_servers or {},
  }
  return self:_send_request("session/load", params)
end
---Send a `session/prompt` request.
---With `callback` the call is asynchronous; without it the call waits for the
---response (see `_send_request`).
---@param session_id string
---@param prompt table[] Content blocks making up the prompt
---@param callback? fun(result: table|nil, err: ACPError|nil)
function ACPClient:send_prompt(session_id, prompt, callback)
  return self:_send_request("session/prompt", { sessionId = session_id, prompt = prompt }, callback)
end
---Send a `session/cancel` notification for the given session.
---@param session_id string
function ACPClient:cancel_session(session_id)
  local params = { sessionId = session_id }
  self:_send_notification("session/cancel", params)
end
---Build a "text" content block.
---@param text string
---@param annotations table?
---@return table
function ACPClient:create_text_content(text, annotations)
  local block = { type = "text" }
  block.text = text
  block.annotations = annotations
  return block
end
---Build an "image" content block.
---@param data string Base64 encoded image data
---@param mime_type string
---@param uri string?
---@param annotations table?
---@return table
function ACPClient:create_image_content(data, mime_type, uri, annotations)
  local block = { type = "image" }
  block.data = data
  block.mimeType = mime_type
  block.uri = uri
  block.annotations = annotations
  return block
end
---Build an "audio" content block.
---@param data string Base64 encoded audio data
---@param mime_type string
---@param annotations table?
---@return table
function ACPClient:create_audio_content(data, mime_type, annotations)
  local block = { type = "audio" }
  block.data = data
  block.mimeType = mime_type
  block.annotations = annotations
  return block
end
---Build a "resource_link" content block.
---@param uri string
---@param name string
---@param description string?
---@param mime_type string?
---@param size number?
---@param title string?
---@param annotations table?
---@return table
function ACPClient:create_resource_link_content(uri, name, description, mime_type, size, title, annotations)
  local block = { type = "resource_link" }
  block.uri = uri
  block.name = name
  block.description = description
  block.mimeType = mime_type
  block.size = size
  block.title = title
  block.annotations = annotations
  return block
end
---Build a "resource" content block wrapping an embedded resource.
---@param resource table Resource table, e.g. from `create_text_resource` or `create_blob_resource`
---@param annotations table?
---@return table
function ACPClient:create_resource_content(resource, annotations)
  local block = { type = "resource" }
  block.resource = resource
  block.annotations = annotations
  return block
end
---Build an embedded resource that carries text.
---@param uri string
---@param text string
---@param mime_type string?
---@return table
function ACPClient:create_text_resource(uri, text, mime_type)
  local resource = {}
  resource.uri = uri
  resource.text = text
  resource.mimeType = mime_type
  return resource
end
---Build an embedded resource that carries binary data.
---@param uri string
---@param blob string Base64 encoded binary data
---@param mime_type string?
---@return table
function ACPClient:create_blob_resource(uri, blob, mime_type)
  local resource = {}
  resource.uri = uri
  resource.blob = blob
  resource.mimeType = mime_type
  return resource
end
---Tell whether the client has reached the "ready" state.
---@return boolean
function ACPClient:is_ready()
  local current = self.state
  return current == "ready"
end
---Tell whether the client is in any state other than "disconnected" or "error".
---@return boolean
function ACPClient:is_connected()
  local current = self.state
  return not (current == "disconnected" or current == "error")
end
---Return the current connection state.
---@return ACPConnectionState
function ACPClient:get_state()
  return self.state
end
---Invoke `callback` once the client is ready, has entered the "error" state,
---or the timeout has elapsed. The state is polled every 100 ms without blocking.
---@param callback function Receives nil on success, or an ACPError on failure
---@param timeout number? Timeout in milliseconds (default 10000)
function ACPClient:wait_ready(callback, timeout)
  local limit_ms = timeout or 10000 -- 10 seconds default
  local started_at = vim.loop.now()

  local function poll()
    if self:is_ready() then
      callback(nil)
      return
    end
    if self.state == "error" then
      callback(self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Client entered error state while waiting"))
      return
    end
    local elapsed = vim.loop.now() - started_at
    if elapsed > limit_ms then
      callback(self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for client to be ready"))
      return
    end
    vim.defer_fn(poll, 100) -- Check every 100ms
  end

  -- The first check happens immediately, so an already-ready client calls back at once.
  poll()
end
---Send a prompt consisting of a single text block and wait for the response.
---The outcome of the underlying `send_prompt` call is returned so callers can
---detect failures.
---@param session_id string
---@param text string
---@return table|nil result
---@return ACPError|nil err
function ACPClient:send_text_prompt(session_id, text)
  local prompt = { self:create_text_content(text) }
  return self:send_prompt(session_id, prompt)
end
return ACPClient