feat: support acp (#2649)
This commit is contained in:
@@ -3,6 +3,7 @@ local fn = vim.fn
|
||||
local uv = vim.uv
|
||||
|
||||
local curl = require("plenary.curl")
|
||||
local ACPClient = require("avante.libs.acp_client")
|
||||
|
||||
local Utils = require("avante.utils")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
@@ -13,6 +14,7 @@ local Providers = require("avante.providers")
|
||||
local LLMToolHelpers = require("avante.llm_tools.helpers")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
local History = require("avante.history")
|
||||
local Selector = require("avante.ui.selector")
|
||||
|
||||
---@class avante.LLM
|
||||
local M = {}
|
||||
@@ -486,6 +488,7 @@ end
|
||||
---@param opts AvanteGeneratePromptsOptions
|
||||
---@return integer
|
||||
function M.calculate_tokens(opts)
|
||||
if Config.acp_providers[Config.provider] then return 0 end
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt)
|
||||
for _, message in ipairs(prompt_opts.messages) do
|
||||
@@ -795,11 +798,339 @@ local function stop_retry_timer()
|
||||
end
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream_acp(opts)
|
||||
---@type table<string, avante.HistoryMessage>
|
||||
local tool_call_messages = {}
|
||||
local acp_provider = Config.acp_providers[Config.provider]
|
||||
local on_messages_add = function(messages)
|
||||
if opts.on_messages_add then opts.on_messages_add(messages) end
|
||||
vim.schedule(function() vim.cmd("redraw") end)
|
||||
end
|
||||
local function add_tool_call_message(update)
|
||||
local message = History.Message:new("assistant", {
|
||||
type = "tool_use",
|
||||
id = update.toolCallId,
|
||||
name = update.kind .. "(" .. update.title .. ")",
|
||||
})
|
||||
if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end
|
||||
tool_call_messages[update.toolCallId] = message
|
||||
if update.rawInput then
|
||||
local path = update.rawInput.path or update.rawInput.file_path
|
||||
if path then
|
||||
local relative_path = Utils.relative_path(path)
|
||||
message.displayed_tool_name = update.title .. "(" .. relative_path .. ")"
|
||||
end
|
||||
local pattern = update.rawInput.pattern or update.rawInput.search
|
||||
if pattern then message.displayed_tool_name = update.title .. "(" .. pattern .. ")" end
|
||||
local command = update.rawInput.command or update.rawInput.command_line
|
||||
if command then message.displayed_tool_name = update.title .. "(" .. command .. ")" end
|
||||
local description = update.rawInput.description
|
||||
if description then
|
||||
message.tool_use_logs = message.tool_use_logs or {}
|
||||
table.insert(message.tool_use_logs, description)
|
||||
end
|
||||
end
|
||||
on_messages_add({ message })
|
||||
return message
|
||||
end
|
||||
local acp_client = opts.acp_client
|
||||
if not acp_client then
|
||||
local acp_config = vim.tbl_deep_extend("force", acp_provider, {
|
||||
handlers = {
|
||||
on_session_update = function(update)
|
||||
if update.sessionUpdate == "plan" then
|
||||
local todos = {}
|
||||
for idx, entry in ipairs(update.entries) do
|
||||
local status = "todo"
|
||||
if entry.status == "in_progress" then status = "doing" end
|
||||
if entry.status == "completed" then status = "done" end
|
||||
---@type avante.TODO
|
||||
local todo = {
|
||||
id = tostring(idx),
|
||||
content = entry.content,
|
||||
status = status,
|
||||
priority = entry.priority,
|
||||
}
|
||||
table.insert(todos, todo)
|
||||
end
|
||||
vim.schedule(function()
|
||||
if opts.update_todos then opts.update_todos(todos) end
|
||||
end)
|
||||
return
|
||||
end
|
||||
if update.sessionUpdate == "agent_message_chunk" then
|
||||
if update.content.type == "text" then
|
||||
local messages = opts.get_history_messages()
|
||||
local last_message = messages[#messages]
|
||||
if last_message and last_message.message.role == "assistant" then
|
||||
local has_text = false
|
||||
local content = last_message.message.content
|
||||
if type(content) == "string" then
|
||||
last_message.message.content = last_message.message.content .. update.content.text
|
||||
has_text = true
|
||||
elseif type(content) == "table" then
|
||||
for idx, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
content[idx] = item .. update.content.text
|
||||
has_text = true
|
||||
end
|
||||
if type(item) == "table" and item.type == "text" then
|
||||
item.text = item.text .. update.content.text
|
||||
has_text = true
|
||||
end
|
||||
end
|
||||
end
|
||||
if has_text then
|
||||
on_messages_add({ last_message })
|
||||
return
|
||||
end
|
||||
end
|
||||
local message = History.Message:new("assistant", update.content.text)
|
||||
on_messages_add({ message })
|
||||
end
|
||||
end
|
||||
if update.sessionUpdate == "agent_thought_chunk" then
|
||||
if update.content.type == "text" then
|
||||
local message = History.Message:new("assistant", {
|
||||
type = "thinking",
|
||||
thinking = update.content.text,
|
||||
})
|
||||
on_messages_add({ message })
|
||||
end
|
||||
end
|
||||
if update.sessionUpdate == "tool_call" then add_tool_call_message(update) end
|
||||
if update.sessionUpdate == "tool_call_update" then
|
||||
local tool_call_message = tool_call_messages[update.toolCallId]
|
||||
if not tool_call_message then
|
||||
tool_call_message = History.Message:new("assistant", {
|
||||
type = "tool_use",
|
||||
id = update.toolCallId,
|
||||
name = "",
|
||||
})
|
||||
local update_content = update.content
|
||||
if type(update_content) == "table" then
|
||||
for _, item in ipairs(update_content) do
|
||||
if item.path then
|
||||
local relative_path = Utils.relative_path(item.path)
|
||||
tool_call_message.displayed_tool_name = "Edit(" .. relative_path .. ")"
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if not tool_call_message.displayed_tool_name then
|
||||
tool_call_message.displayed_tool_name = update.toolCallId
|
||||
end
|
||||
end
|
||||
tool_call_message.tool_use_logs = tool_call_message.tool_use_logs or {}
|
||||
tool_call_message.tool_use_log_lines = tool_call_message.tool_use_log_lines or {}
|
||||
local tool_result_message
|
||||
if update.status == "pending" or update.status == "in_progress" then
|
||||
tool_call_message.is_calling = true
|
||||
tool_call_message.state = "generating"
|
||||
else
|
||||
tool_call_message.is_calling = false
|
||||
tool_call_message.state = "generated"
|
||||
tool_result_message = History.Message:new("assistant", {
|
||||
type = "tool_result",
|
||||
tool_use_id = update.toolCallId,
|
||||
content = update.content,
|
||||
is_error = update.status == "failed",
|
||||
is_user_declined = update.status == "cancelled",
|
||||
})
|
||||
end
|
||||
local messages = { tool_call_message }
|
||||
if tool_result_message then table.insert(messages, tool_result_message) end
|
||||
on_messages_add(messages)
|
||||
end
|
||||
end,
|
||||
on_request_permission = function(tool_call, options, callback)
|
||||
local message = add_tool_call_message(tool_call)
|
||||
local items = vim
|
||||
.iter(options)
|
||||
:map(
|
||||
function(item)
|
||||
return {
|
||||
id = item.optionId,
|
||||
title = item.name,
|
||||
}
|
||||
end
|
||||
)
|
||||
:totable()
|
||||
local default_item = vim.iter(items):find(function(item) return item.id == options[1].optionId end)
|
||||
|
||||
local function on_select(item_ids)
|
||||
if not item_ids then return end
|
||||
local choice = vim.iter(items):find(function(item) return item.id == item_ids[1] end)
|
||||
if not choice then return end
|
||||
Utils.debug("on_select", choice.id)
|
||||
callback(choice.id)
|
||||
end
|
||||
|
||||
local selector = Selector:new({
|
||||
title = message.displayed_tool_name or message.message.content[1].name,
|
||||
items = items,
|
||||
default_item_id = default_item and default_item.name or nil,
|
||||
provider = Config.selector.provider,
|
||||
provider_opts = Config.selector.provider_opts,
|
||||
on_select = on_select,
|
||||
get_preview_content = function(_)
|
||||
local file_content = ""
|
||||
local filetype = "text"
|
||||
local content = tool_call.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "content" then
|
||||
if type(item.content) == "table" then
|
||||
if item.content.type == "text" then
|
||||
file_content = file_content .. item.content.text .. "\n\n"
|
||||
end
|
||||
end
|
||||
end
|
||||
if item.type == "diff" then
|
||||
local unified_diff = Utils.get_unified_diff(item.oldText, item.newText, { algorithm = "myers" })
|
||||
local result = "--- a/" .. item.path .. "\n+++ b/" .. item.path .. "\n" .. unified_diff .. "\n\n"
|
||||
filetype = "diff"
|
||||
file_content = file_content .. result
|
||||
end
|
||||
end
|
||||
end
|
||||
return file_content, filetype
|
||||
end,
|
||||
})
|
||||
|
||||
vim.schedule(function() selector:open() end)
|
||||
end,
|
||||
on_read_file = function(path, line, limit)
|
||||
local abs_path = Utils.to_absolute_path(path)
|
||||
local lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
lines = lines or {}
|
||||
if line ~= nil and limit ~= nil then lines = vim.list_slice(lines, line, line + limit) end
|
||||
return table.concat(lines, "\n")
|
||||
end,
|
||||
on_write_file = function(path, content)
|
||||
local abs_path = Utils.to_absolute_path(path)
|
||||
local file = io.open(abs_path, "w")
|
||||
if file then
|
||||
file:write(content)
|
||||
file:close()
|
||||
return nil
|
||||
end
|
||||
return "Failed to write file: " .. abs_path
|
||||
end,
|
||||
},
|
||||
})
|
||||
acp_client = ACPClient:new(acp_config)
|
||||
acp_client:connect()
|
||||
opts.on_save_acp_client(acp_client)
|
||||
end
|
||||
local session_id = opts.acp_session_id
|
||||
if not session_id then
|
||||
local project_root = Utils.root.get()
|
||||
local session_id_, err = acp_client:create_session(project_root, {})
|
||||
if err then
|
||||
opts.on_stop({ reason = "error", error = err })
|
||||
return
|
||||
end
|
||||
if not session_id_ then
|
||||
opts.on_stop({ reason = "error", error = "Failed to create session" })
|
||||
return
|
||||
end
|
||||
session_id = session_id_
|
||||
opts.on_save_acp_session_id(session_id)
|
||||
end
|
||||
local prompt = {}
|
||||
local history_messages = opts.history_messages or {}
|
||||
if opts.acp_session_id then
|
||||
for i = #history_messages, 1, -1 do
|
||||
local message = history_messages[i]
|
||||
if message.message.role == "user" then
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item,
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "text" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item.text,
|
||||
})
|
||||
end
|
||||
end
|
||||
elseif type(content) == "string" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = content,
|
||||
})
|
||||
end
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
for _, message in ipairs(history_messages) do
|
||||
if message.message.role == "user" then
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item,
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "text" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item.text,
|
||||
})
|
||||
end
|
||||
end
|
||||
else
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if opts.selected_filepaths then
|
||||
for _, filepath in ipairs(opts.selected_filepaths) do
|
||||
local lines, error = Utils.read_file_from_buf_or_disk(filepath)
|
||||
if error ~= nil then
|
||||
Utils.error("error reading file: " .. error)
|
||||
else
|
||||
local abs_path = Utils.to_absolute_path(filepath)
|
||||
local content = table.concat(lines or {}, "\n")
|
||||
local filetype = Utils.get_filetype(filepath)
|
||||
local prompt_item = acp_client:create_resource_content({
|
||||
uri = "file://" .. abs_path,
|
||||
mimeType = "text/x-" .. filetype,
|
||||
text = content,
|
||||
}, nil)
|
||||
table.insert(prompt, prompt_item)
|
||||
end
|
||||
end
|
||||
end
|
||||
acp_client:send_prompt(session_id, prompt, function(_, err_)
|
||||
if err_ then
|
||||
opts.on_stop({ reason = "error", error = err_ })
|
||||
return
|
||||
end
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end)
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream(opts)
|
||||
-- Reset the cancellation flag at the start of a new request
|
||||
if LLMToolHelpers then LLMToolHelpers.is_cancelled = false end
|
||||
|
||||
local acp_provider = Config.acp_providers[Config.provider]
|
||||
if acp_provider then return M._stream_acp(opts) end
|
||||
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
opts.session_ctx = opts.session_ctx or {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user