fix: make acp client send requests async (#2848)
This commit is contained in:
@@ -242,7 +242,6 @@ function ACPClient:new(config)
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
debug_log_file = nil,
|
debug_log_file = nil,
|
||||||
pending_responses = {},
|
|
||||||
callbacks = {},
|
callbacks = {},
|
||||||
transport = nil,
|
transport = nil,
|
||||||
config = config or {},
|
config = config or {},
|
||||||
@@ -388,7 +387,7 @@ function ACPClient:_create_stdio_transport()
|
|||||||
if self.config.reconnect and self.reconnect_count < (self.config.max_reconnect_attempts or 3) then
|
if self.config.reconnect and self.reconnect_count < (self.config.max_reconnect_attempts or 3) then
|
||||||
self.reconnect_count = self.reconnect_count + 1
|
self.reconnect_count = self.reconnect_count + 1
|
||||||
vim.defer_fn(function()
|
vim.defer_fn(function()
|
||||||
if self.state == "disconnected" then self:connect() end
|
if self.state == "disconnected" then self:connect(function(_err) end) end
|
||||||
end, 2000) -- Wait 2 seconds before reconnecting
|
end, 2000) -- Wait 2 seconds before reconnecting
|
||||||
end
|
end
|
||||||
end)
|
end)
|
||||||
@@ -493,9 +492,7 @@ end
|
|||||||
---Send JSON-RPC request
|
---Send JSON-RPC request
|
||||||
---@param method string
|
---@param method string
|
||||||
---@param params table?
|
---@param params table?
|
||||||
---@param callback? fun(result: table|nil, err: avante.acp.ACPError|nil)
|
---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
|
||||||
---@return table|nil result
|
|
||||||
---@return avante.acp.ACPError|nil err
|
|
||||||
function ACPClient:_send_request(method, params, callback)
|
function ACPClient:_send_request(method, params, callback)
|
||||||
local id = self:_next_id()
|
local id = self:_next_id()
|
||||||
local message = {
|
local message = {
|
||||||
@@ -505,30 +502,11 @@ function ACPClient:_send_request(method, params, callback)
|
|||||||
params = params or {},
|
params = params or {},
|
||||||
}
|
}
|
||||||
|
|
||||||
if callback then self.callbacks[id] = callback end
|
self.callbacks[id] = callback
|
||||||
|
|
||||||
local data = vim.json.encode(message)
|
local data = vim.json.encode(message)
|
||||||
self:_debug_log("request: " .. data .. string.rep("=", 100) .. "\n")
|
self:_debug_log("request: " .. data .. string.rep("=", 100) .. "\n")
|
||||||
if not self.transport:send(data) then return nil end
|
self.transport:send(data)
|
||||||
|
|
||||||
if not callback then return self:_wait_response(id) end
|
|
||||||
end
|
|
||||||
|
|
||||||
function ACPClient:_wait_response(id)
|
|
||||||
local start_time = vim.loop.now()
|
|
||||||
local timeout = self.config.timeout or 100000
|
|
||||||
|
|
||||||
while vim.loop.now() - start_time < timeout do
|
|
||||||
vim.wait(10)
|
|
||||||
|
|
||||||
if self.pending_responses[id] then
|
|
||||||
local result, err = unpack(self.pending_responses[id])
|
|
||||||
self.pending_responses[id] = nil
|
|
||||||
return result, err
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return nil, self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for response")
|
|
||||||
end
|
end
|
||||||
|
|
||||||
---Send JSON-RPC notification
|
---Send JSON-RPC notification
|
||||||
@@ -584,8 +562,6 @@ function ACPClient:_handle_message(message)
|
|||||||
if callback then
|
if callback then
|
||||||
callback(message.result, message.error)
|
callback(message.result, message.error)
|
||||||
self.callbacks[message.id] = nil
|
self.callbacks[message.id] = nil
|
||||||
else
|
|
||||||
self.pending_responses[message.id] = { message.result, message.error }
|
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
-- Unknown message type
|
-- Unknown message type
|
||||||
@@ -715,111 +691,139 @@ function ACPClient:_handle_write_text_file(message_id, params)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---Start client
|
---Start client
|
||||||
function ACPClient:connect()
|
---@param callback fun(err: avante.acp.ACPError|nil)
|
||||||
if self.state ~= "disconnected" then return end
|
function ACPClient:connect(callback)
|
||||||
|
callback = callback or function() end
|
||||||
|
|
||||||
self.transport:start(function(message) self:_handle_message(message) end)
|
if self.state ~= "disconnected" then
|
||||||
|
callback(nil)
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
self:initialize()
|
self.transport:start(vim.schedule_wrap(function(message) self:_handle_message(message) end))
|
||||||
|
|
||||||
|
self:initialize(callback)
|
||||||
end
|
end
|
||||||
|
|
||||||
---Stop client
|
---Stop client
|
||||||
function ACPClient:stop()
|
function ACPClient:stop()
|
||||||
self.transport:stop()
|
self.transport:stop()
|
||||||
self:_close_debug_log()
|
self:_close_debug_log()
|
||||||
self.pending_responses = {}
|
|
||||||
self.reconnect_count = 0
|
self.reconnect_count = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
---Initialize protocol connection
|
---Initialize protocol connection
|
||||||
function ACPClient:initialize()
|
---@param callback fun(err: avante.acp.ACPError|nil)
|
||||||
|
function ACPClient:initialize(callback)
|
||||||
|
callback = callback or function() end
|
||||||
|
|
||||||
if self.state ~= "connected" then
|
if self.state ~= "connected" then
|
||||||
local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected")
|
local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected")
|
||||||
return error
|
callback(error)
|
||||||
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
self:_set_state("initializing")
|
self:_set_state("initializing")
|
||||||
|
|
||||||
local result = self:_send_request("initialize", {
|
self:_send_request("initialize", {
|
||||||
protocolVersion = self.protocol_version,
|
protocolVersion = self.protocol_version,
|
||||||
clientCapabilities = self.capabilities,
|
clientCapabilities = self.capabilities,
|
||||||
})
|
}, function(result, err)
|
||||||
|
if err or not result then
|
||||||
|
self:_set_state("error")
|
||||||
|
vim.schedule(function() vim.notify("Failed to initialize", vim.log.levels.ERROR) end)
|
||||||
|
callback(err or self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to initialize: missing result"))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
if not result then
|
-- Update protocol version and capabilities
|
||||||
self:_set_state("error")
|
self.protocol_version = result.protocolVersion
|
||||||
vim.notify("Failed to initialize", vim.log.levels.ERROR)
|
self.agent_capabilities = result.agentCapabilities
|
||||||
return
|
self.auth_methods = result.authMethods or {}
|
||||||
end
|
|
||||||
|
|
||||||
-- Update protocol version and capabilities
|
-- Check if we need to authenticate
|
||||||
self.protocol_version = result.protocolVersion
|
local auth_method = self.config.auth_method
|
||||||
self.agent_capabilities = result.agentCapabilities
|
|
||||||
self.auth_methods = result.authMethods or {}
|
|
||||||
|
|
||||||
-- Check if we need to authenticate
|
if auth_method then
|
||||||
local auth_method = self.config.auth_method
|
Utils.debug("Authenticating with method " .. auth_method)
|
||||||
|
self:authenticate(auth_method, function(auth_err)
|
||||||
if auth_method then
|
if auth_err then
|
||||||
Utils.debug("Authenticating with method " .. auth_method)
|
callback(auth_err)
|
||||||
self:authenticate(auth_method)
|
else
|
||||||
self:_set_state("ready")
|
self:_set_state("ready")
|
||||||
else
|
callback(nil)
|
||||||
Utils.debug("No authentication method found or specified")
|
end
|
||||||
self:_set_state("ready")
|
end)
|
||||||
end
|
else
|
||||||
|
Utils.debug("No authentication method found or specified")
|
||||||
|
self:_set_state("ready")
|
||||||
|
callback(nil)
|
||||||
|
end
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
---Authentication (if required)
|
---Authentication (if required)
|
||||||
---@param method_id string
|
---@param method_id string
|
||||||
function ACPClient:authenticate(method_id)
|
---@param callback fun(err: avante.acp.ACPError|nil)
|
||||||
return self:_send_request("authenticate", {
|
function ACPClient:authenticate(method_id, callback)
|
||||||
|
callback = callback or function() end
|
||||||
|
|
||||||
|
self:_send_request("authenticate", {
|
||||||
methodId = method_id,
|
methodId = method_id,
|
||||||
})
|
}, function(result, err) callback(err) end)
|
||||||
end
|
end
|
||||||
|
|
||||||
---Create new session
|
---Create new session
|
||||||
---@param cwd string
|
---@param cwd string
|
||||||
---@param mcp_servers table[]?
|
---@param mcp_servers table[]?
|
||||||
---@return string|nil session_id
|
---@param callback fun(session_id: string|nil, err: avante.acp.ACPError|nil)
|
||||||
---@return avante.acp.ACPError|nil err
|
function ACPClient:create_session(cwd, mcp_servers, callback)
|
||||||
function ACPClient:create_session(cwd, mcp_servers)
|
callback = callback or function() end
|
||||||
local result, err = self:_send_request("session/new", {
|
|
||||||
|
self:_send_request("session/new", {
|
||||||
cwd = cwd,
|
cwd = cwd,
|
||||||
mcpServers = mcp_servers or {},
|
mcpServers = mcp_servers or {},
|
||||||
})
|
}, function(result, err)
|
||||||
if err then
|
if err then
|
||||||
vim.notify("Failed to create session: " .. err.message, vim.log.levels.ERROR)
|
vim.schedule(function() vim.notify("Failed to create session: " .. err.message, vim.log.levels.ERROR) end)
|
||||||
return nil, err
|
callback(nil, err)
|
||||||
end
|
return
|
||||||
if not result then
|
end
|
||||||
err = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result")
|
if not result then
|
||||||
return nil, err
|
local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result")
|
||||||
end
|
callback(nil, error)
|
||||||
return result.sessionId, nil
|
return
|
||||||
|
end
|
||||||
|
callback(result.sessionId, nil)
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
---Load existing session
|
---Load existing session
|
||||||
---@param session_id string
|
---@param session_id string
|
||||||
---@param cwd string
|
---@param cwd string
|
||||||
---@param mcp_servers table[]?
|
---@param mcp_servers table[]?
|
||||||
---@return table|nil result
|
---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
|
||||||
function ACPClient:load_session(session_id, cwd, mcp_servers)
|
function ACPClient:load_session(session_id, cwd, mcp_servers, callback)
|
||||||
|
callback = callback or function() end
|
||||||
|
|
||||||
if not self.agent_capabilities or not self.agent_capabilities.loadSession then
|
if not self.agent_capabilities or not self.agent_capabilities.loadSession then
|
||||||
vim.notify("Agent does not support loading sessions", vim.log.levels.WARN)
|
vim.schedule(function() vim.notify("Agent does not support loading sessions", vim.log.levels.WARN) end)
|
||||||
|
local err = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Agent does not support loading sessions")
|
||||||
|
callback(nil, err)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
return self:_send_request("session/load", {
|
self:_send_request("session/load", {
|
||||||
sessionId = session_id,
|
sessionId = session_id,
|
||||||
cwd = cwd,
|
cwd = cwd,
|
||||||
mcpServers = mcp_servers or {},
|
mcpServers = mcp_servers or {},
|
||||||
})
|
}, callback)
|
||||||
end
|
end
|
||||||
|
|
||||||
---Send prompt
|
---Send prompt
|
||||||
---@param session_id string
|
---@param session_id string
|
||||||
---@param prompt table[]
|
---@param prompt table[]
|
||||||
---@param callback? fun(result: table|nil, err: avante.acp.ACPError|nil)
|
---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
|
||||||
function ACPClient:send_prompt(session_id, prompt, callback)
|
function ACPClient:send_prompt(session_id, prompt, callback)
|
||||||
local params = {
|
local params = {
|
||||||
sessionId = session_id,
|
sessionId = session_id,
|
||||||
@@ -980,9 +984,10 @@ end
|
|||||||
---Convenience method: Send simple text prompt
|
---Convenience method: Send simple text prompt
|
||||||
---@param session_id string
|
---@param session_id string
|
||||||
---@param text string
|
---@param text string
|
||||||
function ACPClient:send_text_prompt(session_id, text)
|
---@param callback fun(result: table|nil, err: avante.acp.ACPError|nil)
|
||||||
|
function ACPClient:send_text_prompt(session_id, text, callback)
|
||||||
local prompt = { self:create_text_content(text) }
|
local prompt = { self:create_text_content(text) }
|
||||||
self:send_prompt(session_id, prompt)
|
self:send_prompt(session_id, prompt, callback)
|
||||||
end
|
end
|
||||||
|
|
||||||
return ACPClient
|
return ACPClient
|
||||||
|
|||||||
@@ -977,6 +977,7 @@ function M._stream_acp(opts)
|
|||||||
return message
|
return message
|
||||||
end
|
end
|
||||||
local acp_client = opts.acp_client
|
local acp_client = opts.acp_client
|
||||||
|
local session_id = opts.acp_session_id
|
||||||
if not acp_client then
|
if not acp_client then
|
||||||
local acp_config = vim.tbl_deep_extend("force", acp_provider, {
|
local acp_config = vim.tbl_deep_extend("force", acp_provider, {
|
||||||
---@type ACPHandlers
|
---@type ACPHandlers
|
||||||
@@ -1298,22 +1299,46 @@ function M._stream_acp(opts)
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
acp_client = ACPClient:new(acp_config)
|
acp_client = ACPClient:new(acp_config)
|
||||||
acp_client:connect()
|
|
||||||
|
|
||||||
-- Register ACP client for global cleanup on exit (Fix Issue #2749)
|
acp_client:connect(function(conn_err)
|
||||||
local client_id = "acp_" .. tostring(acp_client) .. "_" .. os.time()
|
if conn_err then
|
||||||
local ok, Avante = pcall(require, "avante")
|
opts.on_stop({ reason = "error", error = conn_err })
|
||||||
if ok and Avante.register_acp_client then Avante.register_acp_client(client_id, acp_client) end
|
return
|
||||||
|
end
|
||||||
|
|
||||||
-- If we create a new client and it does not support sesion loading,
|
-- Register ACP client for global cleanup on exit (Fix Issue #2749)
|
||||||
-- remove the old session
|
local client_id = "acp_" .. tostring(acp_client) .. "_" .. os.time()
|
||||||
if not acp_client.agent_capabilities.loadSession then opts.acp_session_id = nil end
|
local ok, Avante = pcall(require, "avante")
|
||||||
if opts.on_save_acp_client then opts.on_save_acp_client(acp_client) end
|
if ok and Avante.register_acp_client then Avante.register_acp_client(client_id, acp_client) end
|
||||||
|
|
||||||
|
-- If we create a new client and it does not support sesion loading,
|
||||||
|
-- remove the old session
|
||||||
|
if not acp_client.agent_capabilities.loadSession then opts.acp_session_id = nil end
|
||||||
|
if opts.on_save_acp_client then opts.on_save_acp_client(acp_client) end
|
||||||
|
|
||||||
|
session_id = opts.acp_session_id
|
||||||
|
if not session_id then
|
||||||
|
M._create_acp_session_and_continue(opts, acp_client)
|
||||||
|
else
|
||||||
|
if opts.just_connect_acp_client then return end
|
||||||
|
M._continue_stream_acp(opts, acp_client, session_id)
|
||||||
|
end
|
||||||
|
end)
|
||||||
|
return
|
||||||
|
elseif not session_id then
|
||||||
|
M._create_acp_session_and_continue(opts, acp_client)
|
||||||
|
return
|
||||||
end
|
end
|
||||||
local session_id = opts.acp_session_id
|
|
||||||
if not session_id then
|
if opts.just_connect_acp_client then return end
|
||||||
local project_root = Utils.root.get()
|
M._continue_stream_acp(opts, acp_client, session_id)
|
||||||
local session_id_, err = acp_client:create_session(project_root, {})
|
end
|
||||||
|
|
||||||
|
---@param opts AvanteLLMStreamOptions
|
||||||
|
---@param acp_client avante.acp.ACPClient
|
||||||
|
function M._create_acp_session_and_continue(opts, acp_client)
|
||||||
|
local project_root = Utils.root.get()
|
||||||
|
acp_client:create_session(project_root, {}, function(session_id_, err)
|
||||||
if err then
|
if err then
|
||||||
opts.on_stop({ reason = "error", error = err })
|
opts.on_stop({ reason = "error", error = err })
|
||||||
return
|
return
|
||||||
@@ -1322,10 +1347,18 @@ function M._stream_acp(opts)
|
|||||||
opts.on_stop({ reason = "error", error = "Failed to create session" })
|
opts.on_stop({ reason = "error", error = "Failed to create session" })
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
session_id = session_id_
|
opts.acp_session_id = session_id_
|
||||||
if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id) end
|
if opts.on_save_acp_session_id then opts.on_save_acp_session_id(session_id_) end
|
||||||
end
|
|
||||||
if opts.just_connect_acp_client then return end
|
if opts.just_connect_acp_client then return end
|
||||||
|
M._continue_stream_acp(opts, acp_client, session_id_)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
---@param opts AvanteLLMStreamOptions
|
||||||
|
---@param acp_client avante.acp.ACPClient
|
||||||
|
---@param session_id string
|
||||||
|
function M._continue_stream_acp(opts, acp_client, session_id)
|
||||||
local prompt = {}
|
local prompt = {}
|
||||||
local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0
|
local donot_use_builtin_system_prompt = opts.history_messages ~= nil and #opts.history_messages > 0
|
||||||
if donot_use_builtin_system_prompt then
|
if donot_use_builtin_system_prompt then
|
||||||
|
|||||||
Reference in New Issue
Block a user