fix: make acp client send requests async (#2848)

This commit is contained in:
Peter Cardenas
2025-12-20 11:32:04 -08:00
committed by GitHub
parent 2afb70537f
commit cf352f6f46
2 changed files with 137 additions and 99 deletions

View File

@@ -242,7 +242,6 @@ function ACPClient:new(config)
}, },
}, },
debug_log_file = nil, debug_log_file = nil,
pending_responses = {},
callbacks = {}, callbacks = {},
transport = nil, transport = nil,
config = config or {}, config = config or {},
@@ -388,7 +387,7 @@ function ACPClient:_create_stdio_transport()
if self.config.reconnect and self.reconnect_count < (self.config.max_reconnect_attempts or 3) then if self.config.reconnect and self.reconnect_count < (self.config.max_reconnect_attempts or 3) then
self.reconnect_count = self.reconnect_count + 1 self.reconnect_count = self.reconnect_count + 1
vim.defer_fn(function() vim.defer_fn(function()
if self.state == "disconnected" then self:connect() end if self.state == "disconnected" then self:connect(function(_err) end) end
end, 2000) -- Wait 2 seconds before reconnecting end, 2000) -- Wait 2 seconds before reconnecting
end end
end) end)
@@ -493,9 +492,7 @@ end
---Send JSON-RPC request ---Send JSON-RPC request
---@param method string ---@param method string
---@param params table? ---@param params table?
---@param callback? fun(result: table|nil, err: avante.acp.ACPError|nil) ---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
---@return table|nil result
---@return avante.acp.ACPError|nil err
function ACPClient:_send_request(method, params, callback) function ACPClient:_send_request(method, params, callback)
local id = self:_next_id() local id = self:_next_id()
local message = { local message = {
@@ -505,30 +502,11 @@ function ACPClient:_send_request(method, params, callback)
params = params or {}, params = params or {},
} }
if callback then self.callbacks[id] = callback end self.callbacks[id] = callback
local data = vim.json.encode(message) local data = vim.json.encode(message)
self:_debug_log("request: " .. data .. string.rep("=", 100) .. "\n") self:_debug_log("request: " .. data .. string.rep("=", 100) .. "\n")
if not self.transport:send(data) then return nil end self.transport:send(data)
if not callback then return self:_wait_response(id) end
end
function ACPClient:_wait_response(id)
local start_time = vim.loop.now()
local timeout = self.config.timeout or 100000
while vim.loop.now() - start_time < timeout do
vim.wait(10)
if self.pending_responses[id] then
local result, err = unpack(self.pending_responses[id])
self.pending_responses[id] = nil
return result, err
end
end
return nil, self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for response")
end end
---Send JSON-RPC notification ---Send JSON-RPC notification
@@ -584,8 +562,6 @@ function ACPClient:_handle_message(message)
if callback then if callback then
callback(message.result, message.error) callback(message.result, message.error)
self.callbacks[message.id] = nil self.callbacks[message.id] = nil
else
self.pending_responses[message.id] = { message.result, message.error }
end end
else else
-- Unknown message type -- Unknown message type
@@ -715,111 +691,139 @@ function ACPClient:_handle_write_text_file(message_id, params)
end end
---Start client ---Start client
function ACPClient:connect() ---@param callback fun(err: avante.acp.ACPError|nil)
if self.state ~= "disconnected" then return end function ACPClient:connect(callback)
callback = callback or function() end
self.transport:start(function(message) self:_handle_message(message) end) if self.state ~= "disconnected" then
callback(nil)
return
end
self:initialize() self.transport:start(vim.schedule_wrap(function(message) self:_handle_message(message) end))
self:initialize(callback)
end end
---Stop client ---Stop client
function ACPClient:stop() function ACPClient:stop()
self.transport:stop() self.transport:stop()
self:_close_debug_log() self:_close_debug_log()
self.pending_responses = {}
self.reconnect_count = 0 self.reconnect_count = 0
end end
---Initialize protocol connection ---Initialize protocol connection
function ACPClient:initialize() ---@param callback fun(err: avante.acp.ACPError|nil)
function ACPClient:initialize(callback)
callback = callback or function() end
if self.state ~= "connected" then if self.state ~= "connected" then
local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected") local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected")
return error callback(error)
return
end end
self:_set_state("initializing") self:_set_state("initializing")
local result = self:_send_request("initialize", { self:_send_request("initialize", {
protocolVersion = self.protocol_version, protocolVersion = self.protocol_version,
clientCapabilities = self.capabilities, clientCapabilities = self.capabilities,
}) }, function(result, err)
if err or not result then
self:_set_state("error")
vim.schedule(function() vim.notify("Failed to initialize", vim.log.levels.ERROR) end)
callback(err or self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to initialize: missing result"))
return
end
if not result then -- Update protocol version and capabilities
self:_set_state("error") self.protocol_version = result.protocolVersion
vim.notify("Failed to initialize", vim.log.levels.ERROR) self.agent_capabilities = result.agentCapabilities
return self.auth_methods = result.authMethods or {}
end
-- Update protocol version and capabilities -- Check if we need to authenticate
self.protocol_version = result.protocolVersion local auth_method = self.config.auth_method
self.agent_capabilities = result.agentCapabilities
self.auth_methods = result.authMethods or {}
-- Check if we need to authenticate if auth_method then
local auth_method = self.config.auth_method Utils.debug("Authenticating with method " .. auth_method)
self:authenticate(auth_method, function(auth_err)
if auth_method then if auth_err then
Utils.debug("Authenticating with method " .. auth_method) callback(auth_err)
self:authenticate(auth_method) else
self:_set_state("ready") self:_set_state("ready")
else callback(nil)
Utils.debug("No authentication method found or specified") end
self:_set_state("ready") end)
end else
Utils.debug("No authentication method found or specified")
self:_set_state("ready")
callback(nil)
end
end)
end end
---Authentication (if required) ---Authentication (if required)
---@param method_id string ---@param method_id string
function ACPClient:authenticate(method_id) ---@param callback fun(err: avante.acp.ACPError|nil)
return self:_send_request("authenticate", { function ACPClient:authenticate(method_id, callback)
callback = callback or function() end
self:_send_request("authenticate", {
methodId = method_id, methodId = method_id,
}) }, function(result, err) callback(err) end)
end end
---Create new session ---Create new session
---@param cwd string ---@param cwd string
---@param mcp_servers table[]? ---@param mcp_servers table[]?
---@return string|nil session_id ---@param callback fun(session_id: string|nil, err: avante.acp.ACPError|nil)
---@return avante.acp.ACPError|nil err function ACPClient:create_session(cwd, mcp_servers, callback)
function ACPClient:create_session(cwd, mcp_servers) callback = callback or function() end
local result, err = self:_send_request("session/new", {
self:_send_request("session/new", {
cwd = cwd, cwd = cwd,
mcpServers = mcp_servers or {}, mcpServers = mcp_servers or {},
}) }, function(result, err)
if err then if err then
vim.notify("Failed to create session: " .. err.message, vim.log.levels.ERROR) vim.schedule(function() vim.notify("Failed to create session: " .. err.message, vim.log.levels.ERROR) end)
return nil, err callback(nil, err)
end return
if not result then end
err = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result") if not result then
return nil, err local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result")
end callback(nil, error)
return result.sessionId, nil return
end
callback(result.sessionId, nil)
end)
end end
---Load existing session ---Load existing session
---@param session_id string ---@param session_id string
---@param cwd string ---@param cwd string
---@param mcp_servers table[]? ---@param mcp_servers table[]?
---@return table|nil result ---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
function ACPClient:load_session(session_id, cwd, mcp_servers) function ACPClient:load_session(session_id, cwd, mcp_servers, callback)
callback = callback or function() end
if not self.agent_capabilities or not self.agent_capabilities.loadSession then if not self.agent_capabilities or not self.agent_capabilities.loadSession then
vim.notify("Agent does not support loading sessions", vim.log.levels.WARN) vim.schedule(function() vim.notify("Agent does not support loading sessions", vim.log.levels.WARN) end)
local err = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Agent does not support loading sessions")
callback(nil, err)
return return
end end
return self:_send_request("session/load", { self:_send_request("session/load", {
sessionId = session_id, sessionId = session_id,
cwd = cwd, cwd = cwd,
mcpServers = mcp_servers or {}, mcpServers = mcp_servers or {},
}) }, callback)
end end
---Send prompt ---Send prompt
---@param session_id string ---@param session_id string
---@param prompt table[] ---@param prompt table[]
---@param callback? fun(result: table|nil, err: avante.acp.ACPError|nil) ---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
function ACPClient:send_prompt(session_id, prompt, callback) function ACPClient:send_prompt(session_id, prompt, callback)
local params = { local params = {
sessionId = session_id, sessionId = session_id,
@@ -980,9 +984,10 @@ end
---Convenience method: Send simple text prompt ---Convenience method: Send simple text prompt
---@param session_id string ---@param session_id string
---@param text string ---@param text string
function ACPClient:send_text_prompt(session_id, text) ---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
function ACPClient:send_text_prompt(session_id, text, callback)
local prompt = { self:create_text_content(text) } local prompt = { self:create_text_content(text) }
self:send_prompt(session_id, prompt) self:send_prompt(session_id, prompt, callback)
end end
return ACPClient return ACPClient

View File

@@ -977,6 +977,7 @@ function M._stream_acp(opts)
return message return message
end end
local acp_client = opts.acp_client local acp_client = opts.acp_client
local session_id = opts.acp_session_id
if not acp_client then if not acp_client then
local acp_config = vim.tbl_deep_extend("force", acp_provider, { local acp_config = vim.tbl_deep_extend("force", acp_provider, {
---@type ACPHandlers ---@type ACPHandlers
@@ -1298,22 +1299,46 @@ function M._stream_acp(opts)
}, },
}) })
acp_client = ACPClient:new(acp_config) acp_client = ACPClient:new(acp_config)
acp_client:connect()
-- Register ACP client for global cleanup on exit (Fix Issue #2749) acp_client:connect(function(conn_err)
local client_id = "acp_" .. tostring(acp_client) .. "_" .. os.time() if conn_err then
local ok, Avante = pcall(require, "avante") opts.on_stop({ reason = "error", error = conn_err })
if ok and Avante.register_acp_client then Avante.register_acp_client(client_id, acp_client) end return
end
-- If we create a new client and it does not support sesion loading, -- Register ACP client for global cleanup on exit (Fix Issue #2749)
-- remove the old session local client_id = "acp_" .. tostring(acp_client) .. "_" .. os.time()
if not acp_client.agent_capabilities.loadSession then opts.acp_session_id = nil end local ok, Avante = pcall(require, "avante")
if opts.on_save_acp_client then opts.on_save_acp_client(acp_client) end if ok and Avante.register_acp_client then Avante.register_acp_client(client_id, acp_client) end
-- If we create a new client and it does not support sesion loading,
-- remove the old session
if not acp_client.agent_capabilities.loadSession then opts.acp_session_id = nil end
if opts.on_save_acp_client then opts.on_save_acp_client(acp_client) end
session_id = opts.acp_session_id
if not session_id then
M._create_acp_session_and_continue(opts, acp_client)
else
if opts.just_connect_acp_client then return end
M._continue_stream_acp(opts, acp_client, session_id)
end
end)
return
elseif not session_id then
M._create_acp_session_and_continue(opts, acp_client)
return
end end
local session_id = opts.acp_session_id
if not session_id then if opts.just_connect_acp_client then return end
local project_root = Utils.root.get() M._continue_stream_acp(opts, acp_client, session_id)
local session_id_, err = acp_client:create_session(project_root, {}) end
---@param opts AvanteLLMStreamOptions
---@param acp_client avante.acp.ACPClient
function M._create_acp_session_and_continue(opts, acp_client)
local project_root = Utils.root.get()
acp_client:create_session(project_root, {}, function(session_id_, err)
if err then if err then
opts.on_stop({ reason = "error", error = err }) opts.on_stop({ reason = "error", error = err })
return return
@@ -1322,10 +1347,18 @@ function M._stream_acp(opts)
opts.on_stop({ reason = "error", error = "Failed to create session" }) opts.on_stop({ reason = "error", error = "Failed to create session" })
return return
end end
session_id = session_id_ opts.acp_session_id = session_id_
if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id) end if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id_) end
end
if opts.just_connect_acp_client then return end if opts.just_connect_acp_client then return end
M._continue_stream_acp(opts, acp_client, session_id_)
end)
end
---@param opts AvanteLLMStreamOptions
---@param acp_client avante.acp.ACPClient
---@param session_id string
function M._continue_stream_acp(opts, acp_client, session_id)
local prompt = {} local prompt = {}
local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0 local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0
if donot_use_builtin_system_prompt then if donot_use_builtin_system_prompt then