feat: support acp (#2649)
This commit is contained in:
76
README.md
76
README.md
@@ -1150,6 +1150,82 @@ providers = {
|
||||
}
|
||||
```
|
||||
|
||||
## ACP Support
|
||||
|
||||
Avante.nvim now supports the [Agent Client Protocol (ACP)](https://agentclientprotocol.com/overview/introduction), enabling seamless integration with AI agents that follow this standardized communication protocol. ACP provides a unified way for AI agents to interact with development environments, offering enhanced capabilities for code editing, file operations, and tool execution.
|
||||
|
||||
### What is ACP?
|
||||
|
||||
The Agent Client Protocol (ACP) is a standardized protocol that enables AI agents to communicate with development tools and environments. It provides:
|
||||
|
||||
- **Standardized Communication**: A unified JSON-RPC based protocol for agent-client interactions
|
||||
- **Tool Integration**: Support for various development tools like file operations, code execution, and search
|
||||
- **Session Management**: Persistent sessions that maintain context across interactions
|
||||
- **Permission System**: Granular control over what agents can access and modify
|
||||
|
||||
### Enabling ACP
|
||||
|
||||
To use ACP-compatible agents with Avante.nvim, you need to configure an ACP provider. Here are the currently supported ACP agents:
|
||||
|
||||
#### Gemini CLI with ACP
|
||||
```lua
|
||||
require('avante').setup({
|
||||
provider = "gemini-cli",
|
||||
-- other configuration options...
|
||||
})
|
||||
```
|
||||
|
||||
#### Claude Code with ACP
|
||||
```lua
|
||||
require('avante').setup({
|
||||
provider = "claude-code",
|
||||
-- other configuration options...
|
||||
})
|
||||
```
|
||||
|
||||
### ACP Configuration
|
||||
|
||||
ACP providers are configured in the `acp_providers` section of your configuration:
|
||||
|
||||
```lua
|
||||
require('avante').setup({
|
||||
acp_providers = {
|
||||
["gemini-cli"] = {
|
||||
command = "gemini",
|
||||
args = { "--experimental-acp" },
|
||||
env = {
|
||||
NODE_NO_WARNINGS = "1",
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY"),
|
||||
},
|
||||
},
|
||||
["claude-code"] = {
|
||||
command = "npx",
|
||||
args = { "acp-claude-code" },
|
||||
env = {
|
||||
NODE_NO_WARNINGS = "1",
|
||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY"),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before using ACP agents, ensure you have the required tools installed:
|
||||
|
||||
- **For Gemini CLI**: Install the `gemini` CLI tool and set your `GEMINI_API_KEY`
|
||||
- **For Claude Code**: Install the `acp-claude-code` package via npm and set your `ANTHROPIC_API_KEY`
|
||||
|
||||
### ACP vs Traditional Providers
|
||||
|
||||
ACP providers offer several advantages over traditional API-based providers:
|
||||
|
||||
- **Enhanced Tool Access**: Agents can directly interact with your file system, run commands, and access development tools
|
||||
- **Persistent Context**: Sessions maintain state across multiple interactions
|
||||
- **Fine-grained Permissions**: Control exactly what agents can access and modify
|
||||
- **Standardized Protocol**: Compatible with any ACP-compliant agent
|
||||
|
||||
## Custom providers
|
||||
|
||||
Avante provides a set of default providers, but users can also create their own providers.
|
||||
|
||||
@@ -237,6 +237,28 @@ M._defaults = {
|
||||
},
|
||||
},
|
||||
},
|
||||
acp_providers = {
|
||||
["gemini-cli"] = {
|
||||
command = "gemini",
|
||||
args = { "--experimental-acp" },
|
||||
env = {
|
||||
NODE_NO_WARNINGS = "1",
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY"),
|
||||
},
|
||||
auth_method = "gemini-api-key",
|
||||
},
|
||||
["claude-code"] = {
|
||||
command = "npx",
|
||||
args = { "acp-claude-code" },
|
||||
env = {
|
||||
NODE_NO_WARNINGS = "1",
|
||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY"),
|
||||
ANTHROPIC_BASE_URL = os.getenv("ANTHROPIC_BASE_URL"),
|
||||
ACP_PATH_TO_CLAUDE_CODE_EXECUTABLE = vim.fn.exepath("claude"),
|
||||
ACP_PERMISSION_MODE = "bypassPermissions",
|
||||
},
|
||||
},
|
||||
},
|
||||
---To add support for custom provider, follow the format below
|
||||
---See https://github.com/yetone/avante.nvim/wiki#custom-providers for more details
|
||||
---@type {[string]: AvanteProvider}
|
||||
@@ -466,7 +488,7 @@ M._defaults = {
|
||||
use_cwd_as_project_root = false,
|
||||
auto_focus_on_diff_view = false,
|
||||
---@type boolean | string[] -- true: auto-approve all tools, false: normal prompts, string[]: auto-approve specific tools by name
|
||||
auto_approve_tool_permissions = false, -- Default: show permission prompts for all tools
|
||||
auto_approve_tool_permissions = true, -- Default: show permission prompts for all tools
|
||||
auto_check_diagnostics = true,
|
||||
enable_fastapply = false,
|
||||
},
|
||||
@@ -774,6 +796,7 @@ end
|
||||
local function apply_model_selection(config, model_name, provider_name)
|
||||
local provider_list = config.providers or {}
|
||||
local current_provider_name = config.provider
|
||||
if config.acp_providers[current_provider_name] then return end
|
||||
|
||||
local target_provider_name = provider_name or current_provider_name
|
||||
local target_provider = provider_list[target_provider_name]
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
local Helpers = require("avante.history.helpers")
|
||||
local Line = require("avante.ui.line")
|
||||
local Utils = require("avante.utils")
|
||||
local Highlights = require("avante.highlights")
|
||||
|
||||
local M = {}
|
||||
|
||||
---@diagnostic disable-next-line: deprecated
|
||||
local islist = vim.islist or vim.tbl_islist
|
||||
|
||||
---Converts text into format suitable for UI
|
||||
---@param text string
|
||||
---@param decoration string | nil
|
||||
---@return avante.ui.Line[]
|
||||
local function text_to_lines(text)
|
||||
local function text_to_lines(text, decoration)
|
||||
local text_lines = vim.split(text, "\n")
|
||||
local lines = {}
|
||||
for _, text_line in ipairs(text_lines) do
|
||||
table.insert(lines, Line:new({ { text_line } }))
|
||||
if decoration then
|
||||
table.insert(lines, Line:new({ { decoration }, { text_line } }))
|
||||
else
|
||||
table.insert(lines, Line:new({ { text_line } }))
|
||||
end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
@@ -22,6 +31,10 @@ end
|
||||
local function thinking_to_lines(item)
|
||||
local text = item.thinking or item.data or ""
|
||||
local text_lines = vim.split(text, "\n")
|
||||
--- trime suffix empty lines
|
||||
while #text_lines > 0 and text_lines[#text_lines] == "" do
|
||||
table.remove(text_lines, #text_lines)
|
||||
end
|
||||
local ui_lines = {}
|
||||
table.insert(ui_lines, Line:new({ { Utils.icon("🤔 ") .. "Thought content:" } }))
|
||||
table.insert(ui_lines, Line:new({ { "" } }))
|
||||
@@ -35,7 +48,7 @@ end
|
||||
---@param tool_name string
|
||||
---@param logs string[]
|
||||
---@return avante.ui.Line[]
|
||||
local function tool_logs_to_lines(tool_name, logs)
|
||||
function M.tool_logs_to_lines(tool_name, logs)
|
||||
local ui_lines = {}
|
||||
local num_logs = #logs
|
||||
|
||||
@@ -44,7 +57,7 @@ local function tool_logs_to_lines(tool_name, logs)
|
||||
local num_lines = #log_lines
|
||||
|
||||
for line_idx = 1, num_lines do
|
||||
local decoration = (log_idx == num_logs and line_idx == num_lines) and "╰─ " or "│ "
|
||||
local decoration = "│ "
|
||||
table.insert(ui_lines, Line:new({ { decoration }, { " " .. log_lines[line_idx] } }))
|
||||
end
|
||||
end
|
||||
@@ -57,14 +70,201 @@ local STATE_TO_HL = {
|
||||
succeeded = "AvanteStateSpinnerSucceeded",
|
||||
}
|
||||
|
||||
function M.get_diff_lines(old_str, new_str, decoration, truncate)
|
||||
local lines = {}
|
||||
local line_count = 0
|
||||
local old_lines = vim.split(old_str, "\n")
|
||||
local new_lines = vim.split(new_str, "\n")
|
||||
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
|
||||
local patch = vim.diff(old_str, new_str, { ---@type integer[][]
|
||||
algorithm = "histogram",
|
||||
result_type = "indices",
|
||||
ctxlen = vim.o.scrolloff,
|
||||
})
|
||||
local prev_start_a = 0
|
||||
for idx, hunk in ipairs(patch) do
|
||||
if truncate and line_count > 10 then
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({
|
||||
{ decoration },
|
||||
{ string.format("... (Result truncated, remaining %d hunks not shown)", #patch - idx + 1) },
|
||||
})
|
||||
)
|
||||
break
|
||||
end
|
||||
local start_a, count_a, start_b, count_b = unpack(hunk)
|
||||
local no_change_lines = vim.list_slice(old_lines, prev_start_a, start_a - 1)
|
||||
local last_tree_no_change_lines = vim.list_slice(no_change_lines, #no_change_lines - 3)
|
||||
if #no_change_lines > 3 then table.insert(lines, Line:new({ { decoration }, { "..." } })) end
|
||||
for _, line in ipairs(last_tree_no_change_lines) do
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, Line:new({ { decoration }, { line } }))
|
||||
end
|
||||
prev_start_a = start_a + count_a
|
||||
if count_a > 0 then
|
||||
local delete_lines = vim.list_slice(old_lines, start_a, start_a + count_a - 1)
|
||||
for _, line in ipairs(delete_lines) do
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, Line:new({ { decoration }, { line, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } }))
|
||||
end
|
||||
end
|
||||
if count_b > 0 then
|
||||
local create_lines = vim.list_slice(new_lines, start_b, start_b + count_b - 1)
|
||||
for _, line in ipairs(create_lines) do
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, Line:new({ { decoration }, { line, Highlights.INCOMING } }))
|
||||
end
|
||||
end
|
||||
end
|
||||
if prev_start_a < #old_lines then
|
||||
-- Append remaining old_lines
|
||||
local no_change_lines = vim.list_slice(old_lines, prev_start_a, #old_lines)
|
||||
local first_tree_no_change_lines = vim.list_slice(no_change_lines, 1, 3)
|
||||
for _, line in ipairs(first_tree_no_change_lines) do
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, Line:new({ { decoration }, { line } }))
|
||||
end
|
||||
if #no_change_lines > 3 then table.insert(lines, Line:new({ { decoration }, { "..." } })) end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---@param content any
|
||||
---@param decoration string | nil
|
||||
---@param truncate boolean | nil
|
||||
function M.get_content_lines(content, decoration, truncate)
|
||||
local lines = {}
|
||||
local content_obj = content
|
||||
if type(content) == "string" then
|
||||
local ok, content_obj_ = pcall(vim.json.decode, content)
|
||||
if ok then content_obj = content_obj_ end
|
||||
end
|
||||
if type(content_obj) == "table" then
|
||||
if islist(content_obj) then
|
||||
local line_count = 0
|
||||
local all_lines = {}
|
||||
for _, content_item in ipairs(content_obj) do
|
||||
if type(content_item) == "string" then
|
||||
local lines_ = text_to_lines(content_item, decoration)
|
||||
all_lines = vim.list_extend(all_lines, lines_)
|
||||
end
|
||||
end
|
||||
for idx, line in ipairs(all_lines) do
|
||||
if truncate and line_count > 3 then
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({
|
||||
{ decoration },
|
||||
{ string.format("... (Result truncated, remaining %d lines not shown)", #all_lines - idx + 1) },
|
||||
})
|
||||
)
|
||||
break
|
||||
end
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, line)
|
||||
end
|
||||
end
|
||||
if type(content_obj.content) == "string" then
|
||||
local line_count = 0
|
||||
local lines_ = text_to_lines(content_obj.content, decoration)
|
||||
for idx, line in ipairs(lines_) do
|
||||
if truncate and line_count > 3 then
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({
|
||||
{ decoration },
|
||||
{ string.format("... (Result truncated, remaining %d lines not shown)", #lines_ - idx + 1) },
|
||||
})
|
||||
)
|
||||
break
|
||||
end
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
if type(content_obj) == "string" then
|
||||
local lines_ = text_to_lines(content_obj, decoration)
|
||||
local line_count = 0
|
||||
for idx, line in ipairs(lines_) do
|
||||
if truncate and line_count > 3 then
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({
|
||||
{ decoration },
|
||||
{ string.format("... (Result truncated, remaining %d lines not shown)", #lines_ - idx + 1) },
|
||||
})
|
||||
)
|
||||
break
|
||||
end
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, line)
|
||||
end
|
||||
end
|
||||
if islist(content) then
|
||||
for _, content_item in ipairs(content) do
|
||||
local line_count = 0
|
||||
if content_item.type == "content" then
|
||||
if content_item.content.type == "text" then
|
||||
local lines_ = text_to_lines(content_item.content.text, decoration)
|
||||
for idx, line in ipairs(lines_) do
|
||||
if truncate and line_count > 3 then
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({
|
||||
{ decoration },
|
||||
{ string.format("... (Result truncated, remaining %d lines not shown)", #lines_ - idx + 1) },
|
||||
})
|
||||
)
|
||||
break
|
||||
end
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, line)
|
||||
end
|
||||
end
|
||||
elseif content_item.type == "diff" then
|
||||
table.insert(lines, Line:new({ { decoration }, { "Path: " .. content_item.path } }))
|
||||
local lines_ = M.get_diff_lines(content_item.oldText, content_item.newText, decoration, truncate)
|
||||
lines = vim.list_extend(lines, lines_)
|
||||
end
|
||||
end
|
||||
end
|
||||
return lines
|
||||
end
|
||||
|
||||
---Converts a tool invocation into format suitable for UI
|
||||
---@param item AvanteLLMMessageContentItem
|
||||
---@param message avante.HistoryMessage
|
||||
---@param messages avante.HistoryMessage[]
|
||||
---@param logs string[]|nil
|
||||
---@return avante.ui.Line[]
|
||||
local function tool_to_lines(item, messages, logs)
|
||||
local function tool_to_lines(item, message, messages)
|
||||
-- local logs = message.tool_use_logs
|
||||
local lines = {}
|
||||
|
||||
local tool_name = item.name
|
||||
|
||||
local rest_input_text_lines = {}
|
||||
|
||||
if message.displayed_tool_name then
|
||||
tool_name = message.displayed_tool_name
|
||||
else
|
||||
if item.input and type(item.input) == "table" then
|
||||
local param
|
||||
if type(item.input.path) == "string" then param = item.input.path end
|
||||
if type(item.input.rel_path) == "string" then param = item.input.rel_path end
|
||||
if type(item.input.filepath) == "string" then param = item.input.filepath end
|
||||
if type(item.input.query) == "string" then param = item.input.query end
|
||||
if type(item.input.pattern) == "string" then param = item.input.pattern end
|
||||
if type(item.input.command) == "string" then
|
||||
param = item.input.command
|
||||
local pieces = vim.split(param, "\n")
|
||||
if #pieces > 1 then param = pieces[1] .. "..." end
|
||||
end
|
||||
if param then tool_name = item.name .. "(" .. param .. ")" end
|
||||
end
|
||||
end
|
||||
|
||||
local result = Helpers.get_tool_result(item.id, messages)
|
||||
local state
|
||||
if not result then
|
||||
@@ -78,11 +278,47 @@ local function tool_to_lines(item, messages, logs)
|
||||
lines,
|
||||
Line:new({
|
||||
{ "╭─ " },
|
||||
{ " " .. item.name .. " ", STATE_TO_HL[state] },
|
||||
{ " " .. tool_name .. " ", STATE_TO_HL[state] },
|
||||
{ " " .. state },
|
||||
})
|
||||
)
|
||||
if logs then vim.list_extend(lines, tool_logs_to_lines(item.name, logs)) end
|
||||
-- if logs then vim.list_extend(lines, tool_logs_to_lines(item.name, logs)) end
|
||||
local decoration = "│ "
|
||||
if rest_input_text_lines and #rest_input_text_lines > 0 then
|
||||
local lines_ = text_to_lines(table.concat(rest_input_text_lines, "\n"), decoration)
|
||||
local line_count = 0
|
||||
for idx, line in ipairs(lines_) do
|
||||
if line_count > 3 then
|
||||
table.insert(
|
||||
lines,
|
||||
Line:new({
|
||||
{ decoration },
|
||||
{ string.format("... (Input truncated, remaining %d lines not shown)", #lines_ - idx + 1) },
|
||||
})
|
||||
)
|
||||
break
|
||||
end
|
||||
line_count = line_count + 1
|
||||
table.insert(lines, line)
|
||||
end
|
||||
table.insert(lines, Line:new({ { decoration }, { "" } }))
|
||||
end
|
||||
if item.input and type(item.input) == "table" then
|
||||
if type(item.input.old_str) == "string" and type(item.input.new_str) == "string" then
|
||||
local diff_lines = M.get_diff_lines(item.input.old_str, item.input.new_str, decoration, true)
|
||||
vim.list_extend(lines, diff_lines)
|
||||
end
|
||||
end
|
||||
if result and result.content then
|
||||
local result_content = result.content
|
||||
if result_content then
|
||||
local content_lines = M.get_content_lines(result_content, decoration, true)
|
||||
vim.list_extend(lines, content_lines)
|
||||
end
|
||||
end
|
||||
if #lines <= 1 then table.insert(lines, Line:new({ { decoration }, { "completed" } })) end
|
||||
local last_line = lines[#lines]
|
||||
last_line.sections[1][1] = "╰─ "
|
||||
return lines
|
||||
end
|
||||
|
||||
@@ -96,7 +332,7 @@ local function message_content_item_to_lines(item, message, messages)
|
||||
return text_to_lines(item)
|
||||
elseif type(item) == "table" then
|
||||
if item.type == "thinking" or item.type == "redacted_thinking" then
|
||||
return thinking_to_lines(item.thinking or item.data or "")
|
||||
return thinking_to_lines(item)
|
||||
elseif item.type == "text" then
|
||||
return text_to_lines(item.text)
|
||||
elseif item.type == "image" then
|
||||
@@ -116,7 +352,9 @@ local function message_content_item_to_lines(item, message, messages)
|
||||
end
|
||||
end
|
||||
|
||||
return tool_to_lines(item, messages, message.tool_use_logs)
|
||||
local lines = tool_to_lines(item, message, messages)
|
||||
if message.tool_use_log_lines then lines = vim.list_extend(lines, message.tool_use_log_lines) end
|
||||
return lines
|
||||
end
|
||||
end
|
||||
return {}
|
||||
|
||||
891
lua/avante/libs/acp_client.lua
Normal file
891
lua/avante/libs/acp_client.lua
Normal file
@@ -0,0 +1,891 @@
|
||||
---@class ClientCapabilities
|
||||
---@field fs FileSystemCapability
|
||||
|
||||
---@class FileSystemCapability
|
||||
---@field readTextFile boolean
|
||||
---@field writeTextFile boolean
|
||||
|
||||
---@class AgentCapabilities
|
||||
---@field loadSession boolean
|
||||
---@field promptCapabilities PromptCapabilities
|
||||
|
||||
---@class PromptCapabilities
|
||||
---@field image boolean
|
||||
---@field audio boolean
|
||||
---@field embeddedContext boolean
|
||||
|
||||
---@class AuthMethod
|
||||
---@field id string
|
||||
---@field name string
|
||||
---@field description string|nil
|
||||
|
||||
---@class McpServer
|
||||
---@field name string
|
||||
---@field command string
|
||||
---@field args string[]
|
||||
---@field env EnvVariable[]
|
||||
|
||||
---@class EnvVariable
|
||||
---@field name string
|
||||
---@field value string
|
||||
|
||||
---@alias StopReason "end_turn" | "max_tokens" | "max_turn_requests" | "refusal" | "cancelled"
|
||||
|
||||
---@alias ToolKind "read" | "edit" | "delete" | "move" | "search" | "execute" | "think" | "fetch" | "other"
|
||||
|
||||
---@alias ToolCallStatus "pending" | "in_progress" | "completed" | "failed"
|
||||
|
||||
---@alias PlanEntryStatus "pending" | "in_progress" | "completed"
|
||||
|
||||
---@alias PlanEntryPriority "high" | "medium" | "low"
|
||||
|
||||
---@class ContentBlock
|
||||
---@field type "text" | "image" | "audio" | "resource_link" | "resource"
|
||||
---@field annotations Annotations|nil
|
||||
|
||||
---@class TextContent : ContentBlock
|
||||
---@field type "text"
|
||||
---@field text string
|
||||
|
||||
---@class ImageContent : ContentBlock
|
||||
---@field type "image"
|
||||
---@field data string
|
||||
---@field mimeType string
|
||||
---@field uri string|nil
|
||||
|
||||
---@class AudioContent : ContentBlock
|
||||
---@field type "audio"
|
||||
---@field data string
|
||||
---@field mimeType string
|
||||
|
||||
---@class ResourceLinkContent : ContentBlock
|
||||
---@field type "resource_link"
|
||||
---@field uri string
|
||||
---@field name string
|
||||
---@field description string|nil
|
||||
---@field mimeType string|nil
|
||||
---@field size number|nil
|
||||
---@field title string|nil
|
||||
|
||||
---@class ResourceContent : ContentBlock
|
||||
---@field type "resource"
|
||||
---@field resource EmbeddedResource
|
||||
|
||||
---@class EmbeddedResource
|
||||
---@field uri string
|
||||
---@field text string|nil
|
||||
---@field blob string|nil
|
||||
---@field mimeType string|nil
|
||||
|
||||
---@class Annotations
|
||||
---@field audience any[]|nil
|
||||
---@field lastModified string|nil
|
||||
---@field priority number|nil
|
||||
|
||||
---@class ToolCall
|
||||
---@field toolCallId string
|
||||
---@field title string
|
||||
---@field kind ToolKind
|
||||
---@field status ToolCallStatus
|
||||
---@field content ToolCallContent[]
|
||||
---@field locations ToolCallLocation[]
|
||||
---@field rawInput table
|
||||
---@field rawOutput table
|
||||
|
||||
---@class ToolCallContent
|
||||
---@field type "content" | "diff"
|
||||
|
||||
---@class ToolCallContentBlock : ToolCallContent
|
||||
---@field type "content"
|
||||
---@field content ContentBlock
|
||||
|
||||
---@class ToolCallDiff : ToolCallContent
|
||||
---@field type "diff"
|
||||
---@field path string
|
||||
---@field oldText string|nil
|
||||
---@field newText string
|
||||
|
||||
---@class ToolCallLocation
|
||||
---@field path string
|
||||
---@field line number|nil
|
||||
|
||||
---@class PlanEntry
|
||||
---@field content string
|
||||
---@field priority PlanEntryPriority
|
||||
---@field status PlanEntryStatus
|
||||
|
||||
---@class Plan
|
||||
---@field entries PlanEntry[]
|
||||
|
||||
---@class SessionUpdate
|
||||
---@field sessionUpdate "user_message_chunk" | "agent_message_chunk" | "agent_thought_chunk" | "tool_call" | "tool_call_update" | "plan"
|
||||
|
||||
---@class UserMessageChunk : SessionUpdate
|
||||
---@field sessionUpdate "user_message_chunk"
|
||||
---@field content ContentBlock
|
||||
|
||||
---@class AgentMessageChunk : SessionUpdate
|
||||
---@field sessionUpdate "agent_message_chunk"
|
||||
---@field content ContentBlock
|
||||
|
||||
---@class AgentThoughtChunk : SessionUpdate
|
||||
---@field sessionUpdate "agent_thought_chunk"
|
||||
---@field content ContentBlock
|
||||
|
||||
---@class ToolCallUpdate : SessionUpdate
|
||||
---@field sessionUpdate "tool_call" | "tool_call_update"
|
||||
---@field toolCallId string
|
||||
---@field title string|nil
|
||||
---@field kind ToolKind|nil
|
||||
---@field status ToolCallStatus|nil
|
||||
---@field content ToolCallContent[]|nil
|
||||
---@field locations ToolCallLocation[]|nil
|
||||
---@field rawInput table|nil
|
||||
---@field rawOutput table|nil
|
||||
|
||||
---@class PlanUpdate : SessionUpdate
|
||||
---@field sessionUpdate "plan"
|
||||
---@field entries PlanEntry[]
|
||||
|
||||
---@class PermissionOption
|
||||
---@field optionId string
|
||||
---@field name string
|
||||
---@field kind "allow_once" | "allow_always" | "reject_once" | "reject_always"
|
||||
|
||||
---@class RequestPermissionOutcome
|
||||
---@field outcome "cancelled" | "selected"
|
||||
---@field optionId string|nil
|
||||
|
||||
---@class ACPTransport
|
||||
---@field send function
|
||||
---@field start function
|
||||
---@field stop function
|
||||
|
||||
---@alias ACPConnectionState "disconnected" | "connecting" | "connected" | "initializing" | "ready" | "error"
|
||||
|
||||
---@class ACPError
|
||||
---@field code number
|
||||
---@field message string
|
||||
---@field data any|nil
|
||||
|
||||
---@class ACPClient
|
||||
---@field protocol_version number
|
||||
---@field capabilities ClientCapabilities
|
||||
---@field agent_capabilities AgentCapabilities|nil
|
||||
---@field config ACPConfig
|
||||
---@field callbacks table<number, fun(result: table|nil, err: ACPError|nil)>
|
||||
local ACPClient = {}
|
||||
|
||||
-- ACP Error codes
|
||||
ACPClient.ERROR_CODES = {
|
||||
TRANSPORT_ERROR = -32000,
|
||||
PROTOCOL_ERROR = -32001,
|
||||
TIMEOUT_ERROR = -32002,
|
||||
AUTH_REQUIRED = -32003,
|
||||
SESSION_NOT_FOUND = -32004,
|
||||
PERMISSION_DENIED = -32005,
|
||||
INVALID_REQUEST = -32006,
|
||||
}
|
||||
|
||||
---@class ACPHandlers
|
||||
---@field on_session_update? function
|
||||
---@field on_request_permission? function
|
||||
---@field on_read_file? function
|
||||
---@field on_write_file? function
|
||||
---@field on_error? function
|
||||
|
||||
---@class ACPConfig
|
||||
---@field transport_type "stdio" | "websocket" | "tcp"
|
||||
---@field command? string Command to spawn agent (for stdio)
|
||||
---@field args? string[] Arguments for agent command
|
||||
---@field env? table Environment variables
|
||||
---@field host? string Host for tcp/websocket
|
||||
---@field port? number Port for tcp/websocket
|
||||
---@field timeout? number Request timeout in milliseconds
|
||||
---@field reconnect? boolean Enable auto-reconnect
|
||||
---@field max_reconnect_attempts? number Maximum reconnection attempts
|
||||
---@field heartbeat_interval? number Heartbeat interval in milliseconds
|
||||
---@field auth_method? string Authentication method
|
||||
---@field handlers? ACPHandlers
|
||||
---@field on_state_change? fun(new_state: ACPConnectionState, old_state: ACPConnectionState)
|
||||
|
||||
---Create a new ACP client instance
|
||||
---@param config ACPConfig
|
||||
---@return ACPClient
|
||||
function ACPClient:new(config)
|
||||
local client = setmetatable({
|
||||
id_counter = 0,
|
||||
protocol_version = 1,
|
||||
capabilities = {
|
||||
fs = {
|
||||
readTextFile = true,
|
||||
writeTextFile = true,
|
||||
},
|
||||
},
|
||||
pending_responses = {},
|
||||
callbacks = {},
|
||||
transport = nil,
|
||||
config = config or {},
|
||||
state = "disconnected",
|
||||
reconnect_count = 0,
|
||||
heartbeat_timer = nil,
|
||||
}, { __index = self })
|
||||
|
||||
client:_setup_transport()
|
||||
return client
|
||||
end
|
||||
|
||||
---Setup transport layer
|
||||
function ACPClient:_setup_transport()
|
||||
local transport_type = self.config.transport_type or "stdio"
|
||||
|
||||
if transport_type == "stdio" then
|
||||
self.transport = self:_create_stdio_transport()
|
||||
elseif transport_type == "websocket" then
|
||||
self.transport = self:_create_websocket_transport()
|
||||
elseif transport_type == "tcp" then
|
||||
self.transport = self:_create_tcp_transport()
|
||||
else
|
||||
error("Unsupported transport type: " .. transport_type)
|
||||
end
|
||||
end
|
||||
|
||||
---Set connection state
|
||||
---@param state ACPConnectionState
|
||||
function ACPClient:_set_state(state)
|
||||
local old_state = self.state
|
||||
self.state = state
|
||||
|
||||
if self.config.on_state_change then self.config.on_state_change(state, old_state) end
|
||||
end
|
||||
|
||||
---Create error object
|
||||
---@param code number
|
||||
---@param message string
|
||||
---@param data any?
|
||||
---@return ACPError
|
||||
function ACPClient:_create_error(code, message, data)
|
||||
return {
|
||||
code = code,
|
||||
message = message,
|
||||
data = data,
|
||||
}
|
||||
end
|
||||
|
||||
---Create stdio transport layer
|
||||
function ACPClient:_create_stdio_transport()
|
||||
local uv = vim.loop
|
||||
local transport = {
|
||||
stdin = nil,
|
||||
stdout = nil,
|
||||
process = nil,
|
||||
}
|
||||
|
||||
function transport.send(transport_self, data)
|
||||
if transport_self.stdin and not transport_self.stdin:is_closing() then
|
||||
transport_self.stdin:write(data .. "\n")
|
||||
return true
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
function transport.start(transport_self, on_message)
|
||||
self:_set_state("connecting")
|
||||
|
||||
local stdin = uv.new_pipe(false)
|
||||
local stdout = uv.new_pipe(false)
|
||||
local stderr = uv.new_pipe(false)
|
||||
|
||||
if not stdin or not stdout or not stderr then
|
||||
self:_set_state("error")
|
||||
error("Failed to create pipes for ACP agent")
|
||||
end
|
||||
|
||||
local args = vim.deepcopy(self.config.args or {})
|
||||
local env = self.config.env
|
||||
|
||||
-- Start with system environment and override with config env
|
||||
local final_env = {}
|
||||
|
||||
local path = vim.fn.getenv("PATH")
|
||||
if path then final_env[#final_env + 1] = "PATH=" .. path end
|
||||
|
||||
if env then
|
||||
for k, v in pairs(env) do
|
||||
final_env[#final_env + 1] = k .. "=" .. v
|
||||
end
|
||||
end
|
||||
|
||||
---@diagnostic disable-next-line: missing-fields
|
||||
local handle = uv.spawn(self.config.command, {
|
||||
args = args,
|
||||
env = final_env,
|
||||
stdio = { stdin, stdout, stderr },
|
||||
}, function(code, signal)
|
||||
vim.print("ACP agent exited with code " .. code .. " and signal " .. signal)
|
||||
self:_set_state("disconnected")
|
||||
|
||||
if transport_self.process then
|
||||
transport_self.process:close()
|
||||
transport_self.process = nil
|
||||
end
|
||||
|
||||
-- Handle auto-reconnect
|
||||
if self.config.reconnect and self.reconnect_count < (self.config.max_reconnect_attempts or 3) then
|
||||
self.reconnect_count = self.reconnect_count + 1
|
||||
vim.defer_fn(function()
|
||||
if self.state == "disconnected" then self:connect() end
|
||||
end, 2000) -- Wait 2 seconds before reconnecting
|
||||
end
|
||||
end)
|
||||
|
||||
if not handle then
|
||||
self:_set_state("error")
|
||||
error("Failed to spawn ACP agent process")
|
||||
end
|
||||
|
||||
transport_self.process = handle
|
||||
transport_self.stdin = stdin
|
||||
transport_self.stdout = stdout
|
||||
|
||||
self:_set_state("connected")
|
||||
|
||||
-- Read stdout
|
||||
local buffer = ""
|
||||
stdout:read_start(function(err, data)
|
||||
-- if data then vim.print("ACP stdout: " .. vim.inspect(data)) end
|
||||
if err then
|
||||
vim.notify("ACP stdout error: " .. err, vim.log.levels.ERROR)
|
||||
self:_set_state("error")
|
||||
return
|
||||
end
|
||||
|
||||
if data then
|
||||
buffer = buffer .. data
|
||||
|
||||
-- Split on newlines and process complete JSON-RPC messages
|
||||
local lines = vim.split(buffer, "\n", { plain = true })
|
||||
buffer = lines[#lines] -- Keep incomplete line in buffer
|
||||
|
||||
for i = 1, #lines - 1 do
|
||||
local line = vim.trim(lines[i])
|
||||
if line ~= "" then
|
||||
local ok, message = pcall(vim.json.decode, line)
|
||||
if ok then
|
||||
on_message(message)
|
||||
else
|
||||
vim.schedule(
|
||||
function() vim.notify("Failed to parse JSON-RPC message: " .. line, vim.log.levels.WARN) end
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end)
|
||||
|
||||
-- Read stderr for debugging
|
||||
stderr:read_start(function(_, data)
|
||||
if data then vim.schedule(function() vim.notify("ACP stderr: " .. data, vim.log.levels.DEBUG) end) end
|
||||
end)
|
||||
end
|
||||
|
||||
function transport.stop(transport_self)
|
||||
if transport_self.process then
|
||||
transport_self.process:close()
|
||||
transport_self.process = nil
|
||||
end
|
||||
if transport_self.stdin then
|
||||
transport_self.stdin:close()
|
||||
transport_self.stdin = nil
|
||||
end
|
||||
if transport_self.stdout then
|
||||
transport_self.stdout:close()
|
||||
transport_self.stdout = nil
|
||||
end
|
||||
self:_set_state("disconnected")
|
||||
end
|
||||
|
||||
return transport
|
||||
end
|
||||
|
||||
---Create WebSocket transport layer (placeholder)
|
||||
function ACPClient:_create_websocket_transport() error("WebSocket transport not implemented yet") end
|
||||
|
||||
---Create TCP transport layer (placeholder)
|
||||
function ACPClient:_create_tcp_transport() error("TCP transport not implemented yet") end
|
||||
|
||||
---Generate next request ID
|
||||
---@return number
|
||||
function ACPClient:_next_id()
|
||||
self.id_counter = self.id_counter + 1
|
||||
return self.id_counter
|
||||
end
|
||||
|
||||
---Send JSON-RPC request
|
||||
---@param method string
|
||||
---@param params table?
|
||||
---@param callback? fun(result: table|nil, err: ACPError|nil)
|
||||
---@return table|nil result
|
||||
---@return ACPError|nil err
|
||||
function ACPClient:_send_request(method, params, callback)
|
||||
local id = self:_next_id()
|
||||
local message = {
|
||||
jsonrpc = "2.0",
|
||||
id = id,
|
||||
method = method,
|
||||
params = params or {},
|
||||
}
|
||||
|
||||
if callback then self.callbacks[id] = callback end
|
||||
|
||||
local data = vim.json.encode(message) .. "\n"
|
||||
if not self.transport:send(data) then return nil end
|
||||
|
||||
if not callback then return self:_wait_response(id) end
|
||||
end
|
||||
|
||||
function ACPClient:_wait_response(id)
|
||||
local start_time = vim.loop.now()
|
||||
local timeout = self.config.timeout or 100000
|
||||
|
||||
while vim.loop.now() - start_time < timeout do
|
||||
vim.wait(10)
|
||||
|
||||
if self.pending_responses[id] then
|
||||
local result, err = unpack(self.pending_responses[id])
|
||||
self.pending_responses[id] = nil
|
||||
return result, err
|
||||
end
|
||||
end
|
||||
|
||||
return nil, self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for response")
|
||||
end
|
||||
|
||||
---Send JSON-RPC notification
|
||||
---@param method string
|
||||
---@param params table?
|
||||
function ACPClient:_send_notification(method, params)
|
||||
local message = {
|
||||
jsonrpc = "2.0",
|
||||
method = method,
|
||||
params = params or {},
|
||||
}
|
||||
|
||||
local data = vim.json.encode(message) .. "\n"
|
||||
self.transport:send(data)
|
||||
end
|
||||
|
||||
---Send JSON-RPC result
|
||||
---@param id number
|
||||
---@param result table
|
||||
---@return nil
|
||||
function ACPClient:_send_result(id, result)
|
||||
local message = { jsonrpc = "2.0", id = id, result = result }
|
||||
|
||||
local data = vim.json.encode(message) .. "\n"
|
||||
-- vim.print("Sending result: " .. data)
|
||||
self.transport:send(data)
|
||||
end
|
||||
|
||||
---Send JSON-RPC error
|
||||
---@param id number
|
||||
---@param message string
|
||||
---@param code? number
|
||||
---@return nil
|
||||
function ACPClient:_send_error(id, message, code)
|
||||
code = code or self.ERROR_CODES.TRANSPORT_ERROR
|
||||
local msg = { jsonrpc = "2.0", id = id, error = { code = code, message = message } }
|
||||
|
||||
local data = vim.json.encode(msg) .. "\n"
|
||||
self.transport:send(data)
|
||||
end
|
||||
|
||||
---Handle received message
|
||||
---@param message table
|
||||
function ACPClient:_handle_message(message)
|
||||
-- vim.print("Received message: " .. vim.inspect(message))
|
||||
-- Check if this is a notification (has method but no id, or has both method and id for notifications)
|
||||
if message.method and not message.result and not message.error then
|
||||
-- This is a notification
|
||||
self:_handle_notification(message.id, message.method, message.params)
|
||||
elseif message.id and (message.result or message.error) then
|
||||
local callback = self.callbacks[message.id]
|
||||
if callback then
|
||||
callback(message.result, message.error)
|
||||
self.callbacks[message.id] = nil
|
||||
else
|
||||
self.pending_responses[message.id] = { message.result, message.error }
|
||||
end
|
||||
else
|
||||
-- Unknown message type
|
||||
vim.notify("Unknown message type: " .. vim.inspect(message), vim.log.levels.WARN)
|
||||
end
|
||||
end
|
||||
|
||||
---Handle notification
|
||||
---@param method string
|
||||
---@param params table
|
||||
function ACPClient:_handle_notification(message_id, method, params)
|
||||
if method == "session/update" then
|
||||
self:_handle_session_update(params)
|
||||
elseif method == "session/request_permission" then
|
||||
self:_handle_request_permission(message_id, params)
|
||||
elseif method == "fs/read_text_file" then
|
||||
self:_handle_read_text_file(message_id, params)
|
||||
elseif method == "fs/write_text_file" then
|
||||
self:_handle_write_text_file(message_id, params)
|
||||
else
|
||||
vim.notify("Unknown notification method: " .. method, vim.log.levels.WARN)
|
||||
end
|
||||
end
|
||||
|
||||
---Handle session update notification
|
||||
---@param params table
|
||||
function ACPClient:_handle_session_update(params)
|
||||
local session_id = params.sessionId
|
||||
local update = params.update
|
||||
|
||||
if not session_id then
|
||||
vim.notify("Received session/update without sessionId", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
if not update then
|
||||
vim.notify("Received session/update without update data", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
if self.config.handlers and self.config.handlers.on_session_update then
|
||||
self.config.handlers.on_session_update(update)
|
||||
end
|
||||
end
|
||||
|
||||
---Handle permission request notification
|
||||
---@param message_id number
|
||||
---@param params table
|
||||
function ACPClient:_handle_request_permission(message_id, params)
|
||||
local session_id = params.sessionId
|
||||
local tool_call = params.toolCall
|
||||
local options = params.options
|
||||
|
||||
if not session_id or not tool_call then return end
|
||||
|
||||
if self.config.handlers and self.config.handlers.on_request_permission then
|
||||
self.config.handlers.on_request_permission(
|
||||
tool_call,
|
||||
options,
|
||||
function(option_id)
|
||||
self:_send_result(message_id, {
|
||||
outcome = {
|
||||
outcome = "selected",
|
||||
optionId = option_id,
|
||||
},
|
||||
})
|
||||
end
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
---Handle fs/read_text_file requests
|
||||
---@param message_id number
|
||||
---@param params table
|
||||
function ACPClient:_handle_read_text_file(message_id, params)
|
||||
local session_id = params.sessionId
|
||||
local path = params.path
|
||||
|
||||
if not session_id or not path then return end
|
||||
|
||||
if self.config.handlers and self.config.handlers.on_read_file then
|
||||
local content = self.config.handlers.on_read_file(path)
|
||||
self:_send_result(message_id, { content = content })
|
||||
end
|
||||
end
|
||||
|
||||
---Handle fs/write_text_file requests
|
||||
---@param message_id number
|
||||
---@param params table
|
||||
function ACPClient:_handle_write_text_file(message_id, params)
|
||||
local session_id = params.sessionId
|
||||
local path = params.path
|
||||
local content = params.content
|
||||
|
||||
if not session_id or not path or not content then return end
|
||||
|
||||
if self.config.handlers and self.config.handlers.on_write_file then
|
||||
local error = self.config.handlers.on_write_file(path, content)
|
||||
self:_send_result(message_id, error == nil and vim.NIL or error)
|
||||
end
|
||||
end
|
||||
|
||||
---Start client
|
||||
function ACPClient:connect()
|
||||
if self.state ~= "disconnected" then return end
|
||||
|
||||
self.transport:start(function(message) self:_handle_message(message) end)
|
||||
|
||||
self:initialize()
|
||||
end
|
||||
|
||||
---Stop client
|
||||
function ACPClient:stop()
|
||||
self.transport:stop()
|
||||
|
||||
self.pending_responses = {}
|
||||
self.reconnect_count = 0
|
||||
end
|
||||
|
||||
---Initialize protocol connection
|
||||
function ACPClient:initialize()
|
||||
if self.state ~= "connected" then
|
||||
local error = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Cannot initialize: client not connected")
|
||||
return error
|
||||
end
|
||||
|
||||
self:_set_state("initializing")
|
||||
|
||||
local result = self:_send_request("initialize", {
|
||||
protocolVersion = self.protocol_version,
|
||||
clientCapabilities = self.capabilities,
|
||||
})
|
||||
|
||||
if not result then
|
||||
self:_set_state("error")
|
||||
vim.notify("Failed to initialize", vim.log.levels.ERROR)
|
||||
return
|
||||
end
|
||||
|
||||
-- Update protocol version and capabilities
|
||||
self.protocol_version = result.protocolVersion
|
||||
self.agent_capabilities = result.agentCapabilities
|
||||
self.auth_methods = result.authMethods or {}
|
||||
|
||||
-- Check if we need to authenticate
|
||||
local auth_method = self.config.auth_method
|
||||
|
||||
if auth_method then
|
||||
vim.print("Authenticating with method " .. auth_method)
|
||||
self:authenticate(auth_method)
|
||||
self:_set_state("ready")
|
||||
else
|
||||
vim.print("No authentication method found or specified")
|
||||
self:_set_state("ready")
|
||||
end
|
||||
end
|
||||
|
||||
---Authentication (if required)
|
||||
---@param method_id string
|
||||
function ACPClient:authenticate(method_id)
|
||||
return self:_send_request("authenticate", {
|
||||
methodId = method_id,
|
||||
})
|
||||
end
|
||||
|
||||
---Create new session
|
||||
---@param cwd string
|
||||
---@param mcp_servers table[]?
|
||||
---@return string|nil session_id
|
||||
---@return ACPError|nil err
|
||||
function ACPClient:create_session(cwd, mcp_servers)
|
||||
local result, err = self:_send_request("session/new", {
|
||||
cwd = cwd,
|
||||
mcpServers = mcp_servers or {},
|
||||
})
|
||||
if err then
|
||||
vim.notify("Failed to create session: " .. err.message, vim.log.levels.ERROR)
|
||||
return nil, err
|
||||
end
|
||||
if not result then
|
||||
err = self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Failed to create session: missing result")
|
||||
return nil, err
|
||||
end
|
||||
return result.sessionId, nil
|
||||
end
|
||||
|
||||
---Load existing session
|
||||
---@param session_id string
|
||||
---@param cwd string
|
||||
---@param mcp_servers table[]?
|
||||
---@return table|nil result
|
||||
function ACPClient:load_session(session_id, cwd, mcp_servers)
|
||||
if not self.agent_capabilities or not self.agent_capabilities.loadSession then
|
||||
vim.notify("Agent does not support loading sessions", vim.log.levels.WARN)
|
||||
return
|
||||
end
|
||||
|
||||
return self:_send_request("session/load", {
|
||||
sessionId = session_id,
|
||||
cwd = cwd,
|
||||
mcpServers = mcp_servers or {},
|
||||
})
|
||||
end
|
||||
|
||||
---Send prompt
|
||||
---@param session_id string
|
||||
---@param prompt table[]
|
||||
---@param callback? fun(result: table|nil, err: ACPError|nil)
|
||||
function ACPClient:send_prompt(session_id, prompt, callback)
|
||||
local params = {
|
||||
sessionId = session_id,
|
||||
prompt = prompt,
|
||||
}
|
||||
return self:_send_request("session/prompt", params, callback)
|
||||
end
|
||||
|
||||
---Cancel session
|
||||
---@param session_id string
|
||||
function ACPClient:cancel_session(session_id)
|
||||
self:_send_notification("session/cancel", {
|
||||
sessionId = session_id,
|
||||
})
|
||||
end
|
||||
|
||||
---Helper function: Create text content block
|
||||
---@param text string
|
||||
---@param annotations table?
|
||||
---@return table
|
||||
function ACPClient:create_text_content(text, annotations)
|
||||
return {
|
||||
type = "text",
|
||||
text = text,
|
||||
annotations = annotations,
|
||||
}
|
||||
end
|
||||
|
||||
---Helper function: Create image content block
|
||||
---@param data string Base64 encoded image data
|
||||
---@param mime_type string
|
||||
---@param uri string?
|
||||
---@param annotations table?
|
||||
---@return table
|
||||
function ACPClient:create_image_content(data, mime_type, uri, annotations)
|
||||
return {
|
||||
type = "image",
|
||||
data = data,
|
||||
mimeType = mime_type,
|
||||
uri = uri,
|
||||
annotations = annotations,
|
||||
}
|
||||
end
|
||||
|
||||
---Helper function: Create audio content block
|
||||
---@param data string Base64 encoded audio data
|
||||
---@param mime_type string
|
||||
---@param annotations table?
|
||||
---@return table
|
||||
function ACPClient:create_audio_content(data, mime_type, annotations)
|
||||
return {
|
||||
type = "audio",
|
||||
data = data,
|
||||
mimeType = mime_type,
|
||||
annotations = annotations,
|
||||
}
|
||||
end
|
||||
|
||||
---Helper function: Create resource link content block
|
||||
---@param uri string
|
||||
---@param name string
|
||||
---@param description string?
|
||||
---@param mime_type string?
|
||||
---@param size number?
|
||||
---@param title string?
|
||||
---@param annotations table?
|
||||
---@return table
|
||||
function ACPClient:create_resource_link_content(uri, name, description, mime_type, size, title, annotations)
|
||||
return {
|
||||
type = "resource_link",
|
||||
uri = uri,
|
||||
name = name,
|
||||
description = description,
|
||||
mimeType = mime_type,
|
||||
size = size,
|
||||
title = title,
|
||||
annotations = annotations,
|
||||
}
|
||||
end
|
||||
|
||||
---Helper function: Create embedded resource content block
|
||||
---@param resource table
|
||||
---@param annotations table?
|
||||
---@return table
|
||||
function ACPClient:create_resource_content(resource, annotations)
|
||||
return {
|
||||
type = "resource",
|
||||
resource = resource,
|
||||
annotations = annotations,
|
||||
}
|
||||
end
|
||||
|
||||
---Helper function: Create text resource
|
||||
---@param uri string
|
||||
---@param text string
|
||||
---@param mime_type string?
|
||||
---@return table
|
||||
function ACPClient:create_text_resource(uri, text, mime_type)
|
||||
return {
|
||||
uri = uri,
|
||||
text = text,
|
||||
mimeType = mime_type,
|
||||
}
|
||||
end
|
||||
|
||||
---Helper function: Create binary resource
|
||||
---@param uri string
|
||||
---@param blob string Base64 encoded binary data
|
||||
---@param mime_type string?
|
||||
---@return table
|
||||
function ACPClient:create_blob_resource(uri, blob, mime_type)
|
||||
return {
|
||||
uri = uri,
|
||||
blob = blob,
|
||||
mimeType = mime_type,
|
||||
}
|
||||
end
|
||||
|
||||
---Convenience method: Check if client is ready
|
||||
---@return boolean
|
||||
function ACPClient:is_ready() return self.state == "ready" end
|
||||
|
||||
---Convenience method: Check if client is connected
|
||||
---@return boolean
|
||||
function ACPClient:is_connected() return self.state ~= "disconnected" and self.state ~= "error" end
|
||||
|
||||
---Convenience method: Get current state
|
||||
---@return ACPConnectionState
|
||||
function ACPClient:get_state() return self.state end
|
||||
|
||||
---Convenience method: Wait for client to be ready
|
||||
---@param callback function
|
||||
---@param timeout number? Timeout in milliseconds
|
||||
function ACPClient:wait_ready(callback, timeout)
|
||||
if self:is_ready() then
|
||||
callback(nil)
|
||||
return
|
||||
end
|
||||
|
||||
local timeout_ms = timeout or 10000 -- 10 seconds default
|
||||
local start_time = vim.loop.now()
|
||||
|
||||
local function check_ready()
|
||||
if self:is_ready() then
|
||||
callback(nil)
|
||||
elseif self.state == "error" then
|
||||
callback(self:_create_error(self.ERROR_CODES.PROTOCOL_ERROR, "Client entered error state while waiting"))
|
||||
elseif vim.loop.now() - start_time > timeout_ms then
|
||||
callback(self:_create_error(self.ERROR_CODES.TIMEOUT_ERROR, "Timeout waiting for client to be ready"))
|
||||
else
|
||||
vim.defer_fn(check_ready, 100) -- Check every 100ms
|
||||
end
|
||||
end
|
||||
|
||||
check_ready()
|
||||
end
|
||||
|
||||
---Convenience method: Send simple text prompt
|
||||
---@param session_id string
|
||||
---@param text string
|
||||
function ACPClient:send_text_prompt(session_id, text)
|
||||
local prompt = { self:create_text_content(text) }
|
||||
self:send_prompt(session_id, prompt)
|
||||
end
|
||||
|
||||
return ACPClient
|
||||
@@ -3,6 +3,7 @@ local fn = vim.fn
|
||||
local uv = vim.uv
|
||||
|
||||
local curl = require("plenary.curl")
|
||||
local ACPClient = require("avante.libs.acp_client")
|
||||
|
||||
local Utils = require("avante.utils")
|
||||
local Prompts = require("avante.utils.prompts")
|
||||
@@ -13,6 +14,7 @@ local Providers = require("avante.providers")
|
||||
local LLMToolHelpers = require("avante.llm_tools.helpers")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
local History = require("avante.history")
|
||||
local Selector = require("avante.ui.selector")
|
||||
|
||||
---@class avante.LLM
|
||||
local M = {}
|
||||
@@ -486,6 +488,7 @@ end
|
||||
---@param opts AvanteGeneratePromptsOptions
|
||||
---@return integer
|
||||
function M.calculate_tokens(opts)
|
||||
if Config.acp_providers[Config.provider] then return 0 end
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt)
|
||||
for _, message in ipairs(prompt_opts.messages) do
|
||||
@@ -795,11 +798,339 @@ local function stop_retry_timer()
|
||||
end
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream_acp(opts)
|
||||
---@type table<string, avante.HistoryMessage>
|
||||
local tool_call_messages = {}
|
||||
local acp_provider = Config.acp_providers[Config.provider]
|
||||
local on_messages_add = function(messages)
|
||||
if opts.on_messages_add then opts.on_messages_add(messages) end
|
||||
vim.schedule(function() vim.cmd("redraw") end)
|
||||
end
|
||||
local function add_tool_call_message(update)
|
||||
local message = History.Message:new("assistant", {
|
||||
type = "tool_use",
|
||||
id = update.toolCallId,
|
||||
name = update.kind .. "(" .. update.title .. ")",
|
||||
})
|
||||
if update.status == "pending" or update.status == "in_progress" then message.is_calling = true end
|
||||
tool_call_messages[update.toolCallId] = message
|
||||
if update.rawInput then
|
||||
local path = update.rawInput.path or update.rawInput.file_path
|
||||
if path then
|
||||
local relative_path = Utils.relative_path(path)
|
||||
message.displayed_tool_name = update.title .. "(" .. relative_path .. ")"
|
||||
end
|
||||
local pattern = update.rawInput.pattern or update.rawInput.search
|
||||
if pattern then message.displayed_tool_name = update.title .. "(" .. pattern .. ")" end
|
||||
local command = update.rawInput.command or update.rawInput.command_line
|
||||
if command then message.displayed_tool_name = update.title .. "(" .. command .. ")" end
|
||||
local description = update.rawInput.description
|
||||
if description then
|
||||
message.tool_use_logs = message.tool_use_logs or {}
|
||||
table.insert(message.tool_use_logs, description)
|
||||
end
|
||||
end
|
||||
on_messages_add({ message })
|
||||
return message
|
||||
end
|
||||
local acp_client = opts.acp_client
|
||||
if not acp_client then
|
||||
local acp_config = vim.tbl_deep_extend("force", acp_provider, {
|
||||
handlers = {
|
||||
on_session_update = function(update)
|
||||
if update.sessionUpdate == "plan" then
|
||||
local todos = {}
|
||||
for idx, entry in ipairs(update.entries) do
|
||||
local status = "todo"
|
||||
if entry.status == "in_progress" then status = "doing" end
|
||||
if entry.status == "completed" then status = "done" end
|
||||
---@type avante.TODO
|
||||
local todo = {
|
||||
id = tostring(idx),
|
||||
content = entry.content,
|
||||
status = status,
|
||||
priority = entry.priority,
|
||||
}
|
||||
table.insert(todos, todo)
|
||||
end
|
||||
vim.schedule(function()
|
||||
if opts.update_todos then opts.update_todos(todos) end
|
||||
end)
|
||||
return
|
||||
end
|
||||
if update.sessionUpdate == "agent_message_chunk" then
|
||||
if update.content.type == "text" then
|
||||
local messages = opts.get_history_messages()
|
||||
local last_message = messages[#messages]
|
||||
if last_message and last_message.message.role == "assistant" then
|
||||
local has_text = false
|
||||
local content = last_message.message.content
|
||||
if type(content) == "string" then
|
||||
last_message.message.content = last_message.message.content .. update.content.text
|
||||
has_text = true
|
||||
elseif type(content) == "table" then
|
||||
for idx, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
content[idx] = item .. update.content.text
|
||||
has_text = true
|
||||
end
|
||||
if type(item) == "table" and item.type == "text" then
|
||||
item.text = item.text .. update.content.text
|
||||
has_text = true
|
||||
end
|
||||
end
|
||||
end
|
||||
if has_text then
|
||||
on_messages_add({ last_message })
|
||||
return
|
||||
end
|
||||
end
|
||||
local message = History.Message:new("assistant", update.content.text)
|
||||
on_messages_add({ message })
|
||||
end
|
||||
end
|
||||
if update.sessionUpdate == "agent_thought_chunk" then
|
||||
if update.content.type == "text" then
|
||||
local message = History.Message:new("assistant", {
|
||||
type = "thinking",
|
||||
thinking = update.content.text,
|
||||
})
|
||||
on_messages_add({ message })
|
||||
end
|
||||
end
|
||||
if update.sessionUpdate == "tool_call" then add_tool_call_message(update) end
|
||||
if update.sessionUpdate == "tool_call_update" then
|
||||
local tool_call_message = tool_call_messages[update.toolCallId]
|
||||
if not tool_call_message then
|
||||
tool_call_message = History.Message:new("assistant", {
|
||||
type = "tool_use",
|
||||
id = update.toolCallId,
|
||||
name = "",
|
||||
})
|
||||
local update_content = update.content
|
||||
if type(update_content) == "table" then
|
||||
for _, item in ipairs(update_content) do
|
||||
if item.path then
|
||||
local relative_path = Utils.relative_path(item.path)
|
||||
tool_call_message.displayed_tool_name = "Edit(" .. relative_path .. ")"
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
if not tool_call_message.displayed_tool_name then
|
||||
tool_call_message.displayed_tool_name = update.toolCallId
|
||||
end
|
||||
end
|
||||
tool_call_message.tool_use_logs = tool_call_message.tool_use_logs or {}
|
||||
tool_call_message.tool_use_log_lines = tool_call_message.tool_use_log_lines or {}
|
||||
local tool_result_message
|
||||
if update.status == "pending" or update.status == "in_progress" then
|
||||
tool_call_message.is_calling = true
|
||||
tool_call_message.state = "generating"
|
||||
else
|
||||
tool_call_message.is_calling = false
|
||||
tool_call_message.state = "generated"
|
||||
tool_result_message = History.Message:new("assistant", {
|
||||
type = "tool_result",
|
||||
tool_use_id = update.toolCallId,
|
||||
content = update.content,
|
||||
is_error = update.status == "failed",
|
||||
is_user_declined = update.status == "cancelled",
|
||||
})
|
||||
end
|
||||
local messages = { tool_call_message }
|
||||
if tool_result_message then table.insert(messages, tool_result_message) end
|
||||
on_messages_add(messages)
|
||||
end
|
||||
end,
|
||||
on_request_permission = function(tool_call, options, callback)
|
||||
local message = add_tool_call_message(tool_call)
|
||||
local items = vim
|
||||
.iter(options)
|
||||
:map(
|
||||
function(item)
|
||||
return {
|
||||
id = item.optionId,
|
||||
title = item.name,
|
||||
}
|
||||
end
|
||||
)
|
||||
:totable()
|
||||
local default_item = vim.iter(items):find(function(item) return item.id == options[1].optionId end)
|
||||
|
||||
local function on_select(item_ids)
|
||||
if not item_ids then return end
|
||||
local choice = vim.iter(items):find(function(item) return item.id == item_ids[1] end)
|
||||
if not choice then return end
|
||||
Utils.debug("on_select", choice.id)
|
||||
callback(choice.id)
|
||||
end
|
||||
|
||||
local selector = Selector:new({
|
||||
title = message.displayed_tool_name or message.message.content[1].name,
|
||||
items = items,
|
||||
default_item_id = default_item and default_item.name or nil,
|
||||
provider = Config.selector.provider,
|
||||
provider_opts = Config.selector.provider_opts,
|
||||
on_select = on_select,
|
||||
get_preview_content = function(_)
|
||||
local file_content = ""
|
||||
local filetype = "text"
|
||||
local content = tool_call.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "content" then
|
||||
if type(item.content) == "table" then
|
||||
if item.content.type == "text" then
|
||||
file_content = file_content .. item.content.text .. "\n\n"
|
||||
end
|
||||
end
|
||||
end
|
||||
if item.type == "diff" then
|
||||
local unified_diff = Utils.get_unified_diff(item.oldText, item.newText, { algorithm = "myers" })
|
||||
local result = "--- a/" .. item.path .. "\n+++ b/" .. item.path .. "\n" .. unified_diff .. "\n\n"
|
||||
filetype = "diff"
|
||||
file_content = file_content .. result
|
||||
end
|
||||
end
|
||||
end
|
||||
return file_content, filetype
|
||||
end,
|
||||
})
|
||||
|
||||
vim.schedule(function() selector:open() end)
|
||||
end,
|
||||
on_read_file = function(path, line, limit)
|
||||
local abs_path = Utils.to_absolute_path(path)
|
||||
local lines = Utils.read_file_from_buf_or_disk(abs_path)
|
||||
lines = lines or {}
|
||||
if line ~= nil and limit ~= nil then lines = vim.list_slice(lines, line, line + limit) end
|
||||
return table.concat(lines, "\n")
|
||||
end,
|
||||
on_write_file = function(path, content)
|
||||
local abs_path = Utils.to_absolute_path(path)
|
||||
local file = io.open(abs_path, "w")
|
||||
if file then
|
||||
file:write(content)
|
||||
file:close()
|
||||
return nil
|
||||
end
|
||||
return "Failed to write file: " .. abs_path
|
||||
end,
|
||||
},
|
||||
})
|
||||
acp_client = ACPClient:new(acp_config)
|
||||
acp_client:connect()
|
||||
opts.on_save_acp_client(acp_client)
|
||||
end
|
||||
local session_id = opts.acp_session_id
|
||||
if not session_id then
|
||||
local project_root = Utils.root.get()
|
||||
local session_id_, err = acp_client:create_session(project_root, {})
|
||||
if err then
|
||||
opts.on_stop({ reason = "error", error = err })
|
||||
return
|
||||
end
|
||||
if not session_id_ then
|
||||
opts.on_stop({ reason = "error", error = "Failed to create session" })
|
||||
return
|
||||
end
|
||||
session_id = session_id_
|
||||
opts.on_save_acp_session_id(session_id)
|
||||
end
|
||||
local prompt = {}
|
||||
local history_messages = opts.history_messages or {}
|
||||
if opts.acp_session_id then
|
||||
for i = #history_messages, 1, -1 do
|
||||
local message = history_messages[i]
|
||||
if message.message.role == "user" then
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item,
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "text" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item.text,
|
||||
})
|
||||
end
|
||||
end
|
||||
elseif type(content) == "string" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = content,
|
||||
})
|
||||
end
|
||||
break
|
||||
end
|
||||
end
|
||||
else
|
||||
for _, message in ipairs(history_messages) do
|
||||
if message.message.role == "user" then
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item,
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "text" then
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = item.text,
|
||||
})
|
||||
end
|
||||
end
|
||||
else
|
||||
table.insert(prompt, {
|
||||
type = "text",
|
||||
text = content,
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
if opts.selected_filepaths then
|
||||
for _, filepath in ipairs(opts.selected_filepaths) do
|
||||
local lines, error = Utils.read_file_from_buf_or_disk(filepath)
|
||||
if error ~= nil then
|
||||
Utils.error("error reading file: " .. error)
|
||||
else
|
||||
local abs_path = Utils.to_absolute_path(filepath)
|
||||
local content = table.concat(lines or {}, "\n")
|
||||
local filetype = Utils.get_filetype(filepath)
|
||||
local prompt_item = acp_client:create_resource_content({
|
||||
uri = "file://" .. abs_path,
|
||||
mimeType = "text/x-" .. filetype,
|
||||
text = content,
|
||||
}, nil)
|
||||
table.insert(prompt, prompt_item)
|
||||
end
|
||||
end
|
||||
end
|
||||
acp_client:send_prompt(session_id, prompt, function(_, err_)
|
||||
if err_ then
|
||||
opts.on_stop({ reason = "error", error = err_ })
|
||||
return
|
||||
end
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end)
|
||||
end
|
||||
|
||||
---@param opts AvanteLLMStreamOptions
|
||||
function M._stream(opts)
|
||||
-- Reset the cancellation flag at the start of a new request
|
||||
if LLMToolHelpers then LLMToolHelpers.is_cancelled = false end
|
||||
|
||||
local acp_provider = Config.acp_providers[Config.provider]
|
||||
if acp_provider then return M._stream_acp(opts) end
|
||||
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
opts.session_ctx = opts.session_ctx or {}
|
||||
|
||||
|
||||
@@ -18,9 +18,7 @@ M.description =
|
||||
|
||||
M.support_streaming = true
|
||||
|
||||
function M.enabled()
|
||||
return require("avante.config").mode == "agentic" and not require("avante.config").behaviour.enable_fastapply
|
||||
end
|
||||
function M.enabled() return false end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
@@ -184,17 +182,7 @@ function M.func(input, opts)
|
||||
-- Utils.debug("diff", diff)
|
||||
local err = [[No diff blocks found.
|
||||
|
||||
Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE blocks are in the correct order.
|
||||
|
||||
For example:
|
||||
```
|
||||
------- SEARCH
|
||||
[exact content to find]
|
||||
=======
|
||||
[new content to replace with]
|
||||
+++++++ REPLACE
|
||||
```
|
||||
]]
|
||||
Please make sure the diff is formatted correctly, and that the SEARCH/REPLACE blocks are in the correct order.]]
|
||||
return false, err
|
||||
end
|
||||
|
||||
|
||||
@@ -8,7 +8,9 @@ M.name = "str_replace"
|
||||
M.description =
|
||||
"The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits."
|
||||
|
||||
function M.enabled() return false end
|
||||
function M.enabled()
|
||||
return require("avante.config").mode == "agentic" and not require("avante.config").behaviour.enable_fastapply
|
||||
end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
|
||||
@@ -203,6 +203,8 @@ M = setmetatable(M, {
|
||||
function M.setup()
|
||||
vim.g.avante_login = false
|
||||
|
||||
if Config.acp_providers[Config.provider] then return end
|
||||
|
||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
local provider = M[Config.provider]
|
||||
|
||||
|
||||
@@ -73,6 +73,8 @@ Sidebar.__index = Sidebar
|
||||
---@field input_hint_window integer | nil
|
||||
---@field old_result_lines avante.ui.Line[]
|
||||
---@field token_count integer | nil
|
||||
---@field acp_client ACPClient | nil
|
||||
---@field acp_session_id string | nil
|
||||
|
||||
---@param id integer the tabpage id retrieved from api.nvim_get_current_tabpage()
|
||||
function Sidebar:new(id)
|
||||
@@ -199,7 +201,6 @@ function Sidebar:set_code_winhl()
|
||||
if not Utils.is_valid_container(self.containers.result, true) then return end
|
||||
|
||||
if Utils.should_hidden_border(self.code.winid, self.containers.result.winid) then
|
||||
Utils.debug("setting winhl")
|
||||
local old_winhl = vim.wo[self.code.winid].winhl
|
||||
if self.code.old_winhl == nil then
|
||||
self.code.old_winhl = old_winhl
|
||||
@@ -1638,9 +1639,15 @@ end
|
||||
---@param selected_code AvanteSelectedCode?
|
||||
---@return string
|
||||
local function render_chat_record_prefix(timestamp, provider, model, request, selected_filepaths, selected_code)
|
||||
provider = provider or "unknown"
|
||||
model = model or "unknown"
|
||||
local res = "- Datetime: " .. timestamp .. "\n" .. "- Model: " .. provider .. "/" .. model
|
||||
local res
|
||||
local acp_provider = Config.acp_providers[provider]
|
||||
if acp_provider then
|
||||
res = "- Datetime: " .. timestamp .. "\n" .. "- ACP: " .. provider
|
||||
else
|
||||
provider = provider or "unknown"
|
||||
model = model or "unknown"
|
||||
res = "- Datetime: " .. timestamp .. "\n" .. "- Model: " .. provider .. "/" .. model
|
||||
end
|
||||
if selected_filepaths ~= nil and #selected_filepaths > 0 then
|
||||
res = res .. "\n- Selected files:"
|
||||
for _, path in ipairs(selected_filepaths) do
|
||||
@@ -1740,10 +1747,26 @@ local _message_to_lines_lru_cache = LRUCache:new(100)
|
||||
---@return avante.ui.Line[]
|
||||
local function get_message_lines(message, messages, ctx)
|
||||
if message.state == "generating" or message.is_calling then return _get_message_lines(message, messages, ctx) end
|
||||
local cached_lines = _message_to_lines_lru_cache:get(message.uuid)
|
||||
local text_len = 0
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) == "string" then
|
||||
text_len = text_len + #item
|
||||
else
|
||||
for _, subitem in ipairs(item) do
|
||||
if type(subitem) == "string" then text_len = text_len + #subitem end
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif type(content) == "string" then
|
||||
text_len = #content
|
||||
end
|
||||
local cache_key = message.uuid .. ":" .. tostring(text_len)
|
||||
local cached_lines = _message_to_lines_lru_cache:get(cache_key)
|
||||
if cached_lines then return cached_lines end
|
||||
local lines = _get_message_lines(message, messages, ctx)
|
||||
_message_to_lines_lru_cache:set(message.uuid, lines)
|
||||
_message_to_lines_lru_cache:set(cache_key, lines)
|
||||
return lines
|
||||
end
|
||||
|
||||
@@ -1934,6 +1957,7 @@ function Sidebar:render_state()
|
||||
if self.current_state == "thinking" then hl = "AvanteStateSpinnerThinking" end
|
||||
if self.current_state == "compacting" then hl = "AvanteStateSpinnerCompacting" end
|
||||
local spinner_char = spinner_chars[self.state_spinner_idx]
|
||||
if not spinner_char then spinner_char = spinner_chars[1] end
|
||||
self.state_spinner_idx = (self.state_spinner_idx % #spinner_chars) + 1
|
||||
if
|
||||
self.current_state ~= "generating"
|
||||
@@ -2004,6 +2028,7 @@ function Sidebar:new_chat(args, cb)
|
||||
Path.history.save(self.code.bufnr, history)
|
||||
self:reload_chat_history()
|
||||
self.current_state = nil
|
||||
self.acp_session_id = nil
|
||||
self:update_content("New chat", { focus = false, scroll = false, callback = function() self:focus_input() end })
|
||||
if cb then cb(args) end
|
||||
vim.schedule(function() self:create_todos_container() end)
|
||||
@@ -2035,13 +2060,16 @@ function Sidebar:update_todos(todos)
|
||||
end
|
||||
|
||||
---@param messages avante.HistoryMessage | avante.HistoryMessage[]
|
||||
function Sidebar:add_history_messages(messages)
|
||||
---@param opts? {eager_update?: boolean}
|
||||
function Sidebar:add_history_messages(messages, opts)
|
||||
local history_messages = History.get_history_messages(self.chat_history)
|
||||
messages = vim.islist(messages) and messages or { messages }
|
||||
for _, message in ipairs(messages) do
|
||||
if message.is_user_submission then
|
||||
message.provider = Config.provider
|
||||
message.model = Config.get_provider_config(Config.provider).model
|
||||
if not Config.acp_providers[Config.provider] then
|
||||
message.model = Config.get_provider_config(Config.provider).model
|
||||
end
|
||||
end
|
||||
local idx = nil
|
||||
for idx_, message_ in ipairs(history_messages) do
|
||||
@@ -2082,6 +2110,10 @@ function Sidebar:add_history_messages(messages)
|
||||
self.current_state = "generating"
|
||||
end
|
||||
end
|
||||
if opts and opts.eager_update then
|
||||
pcall(function() self:update_content("") end)
|
||||
return
|
||||
end
|
||||
xpcall(function() self:throttled_update_content("") end, function(err)
|
||||
Utils.debug("Failed to update content:", err)
|
||||
return nil
|
||||
@@ -2275,13 +2307,15 @@ function Sidebar:get_history_messages_for_api(opts)
|
||||
:totable()
|
||||
end
|
||||
|
||||
local tool_limit
|
||||
if Providers[Config.provider].use_ReAct_prompt then
|
||||
tool_limit = nil
|
||||
else
|
||||
tool_limit = 25
|
||||
if not Config.acp_providers[Config.provider] then
|
||||
local tool_limit
|
||||
if Providers[Config.provider].use_ReAct_prompt then
|
||||
tool_limit = nil
|
||||
else
|
||||
tool_limit = 25
|
||||
end
|
||||
messages = History.update_tool_invocation_history(messages, tool_limit, Config.behaviour.auto_check_diagnostics)
|
||||
end
|
||||
messages = History.update_tool_invocation_history(messages, tool_limit, Config.behaviour.auto_check_diagnostics)
|
||||
end
|
||||
|
||||
return messages
|
||||
@@ -2590,12 +2624,17 @@ function Sidebar:create_input_container()
|
||||
on_tool_log = on_tool_log,
|
||||
on_messages_add = on_messages_add,
|
||||
on_state_change = on_state_change,
|
||||
acp_client = self.acp_client,
|
||||
on_save_acp_client = function(client) self.acp_client = client end,
|
||||
acp_session_id = self.acp_session_id,
|
||||
on_save_acp_session_id = function(session_id) self.acp_session_id = session_id end,
|
||||
set_tool_use_store = set_tool_use_store,
|
||||
get_history_messages = function(opts) return self:get_history_messages_for_api(opts) end,
|
||||
get_todos = function()
|
||||
local history = Path.history.load(self.code.bufnr)
|
||||
return history and history.todos or {}
|
||||
end,
|
||||
update_todos = function(todos) self:update_todos(todos) end,
|
||||
session_ctx = {},
|
||||
---@param usage avante.LLMTokenUsage
|
||||
update_tokens_usage = function(usage)
|
||||
|
||||
@@ -98,6 +98,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field timestamp string
|
||||
---@field state avante.HistoryMessageState
|
||||
---@field uuid string | nil
|
||||
---@field displayed_tool_name string | nil
|
||||
---@field displayed_content string | nil
|
||||
---@field visible boolean | nil
|
||||
---@field is_context boolean | nil
|
||||
@@ -107,6 +108,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field selected_code AvanteSelectedCode | nil
|
||||
---@field selected_filepaths string[] | nil
|
||||
---@field tool_use_logs string[] | nil
|
||||
---@field tool_use_log_lines avante.ui.Line[] | nil
|
||||
---@field tool_use_store table | nil
|
||||
---@field just_for_display boolean | nil
|
||||
---@field is_dummy boolean | nil
|
||||
@@ -327,7 +329,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---
|
||||
---@class AvanteProviderFunctor
|
||||
---@field _model_list_cache table
|
||||
---@field extra_headers function(table) -> table | table | nil
|
||||
---@field extra_headers fun(table): table | table | nil
|
||||
---@field support_prompt_caching boolean | nil
|
||||
---@field role_map table<"user" | "assistant", string>
|
||||
---@field parse_messages AvanteMessagesParser
|
||||
@@ -359,6 +361,12 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field parse_response AvanteResponseParser
|
||||
---@field build_bedrock_payload AvanteBedrockPayloadBuilder
|
||||
---
|
||||
---@class AvanteACPProvider
|
||||
---@field command string
|
||||
---@field args string[]
|
||||
---@field env table<string, string>
|
||||
---@field auth_method string
|
||||
---
|
||||
---@alias AvanteLlmMode avante.Mode | "editing" | "suggesting"
|
||||
---
|
||||
---@class AvanteSelectedCode
|
||||
@@ -382,6 +390,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field diagnostics string | nil
|
||||
---@field history_messages avante.HistoryMessage[] | nil
|
||||
---@field get_todos? fun(): avante.TODO[]
|
||||
---@field update_todos? fun(todos: avante.TODO[]): nil
|
||||
---@field memory string | nil
|
||||
---@field get_tokens_usage? fun(): avante.LLMTokenUsage | nil
|
||||
---
|
||||
@@ -405,6 +414,10 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking" | "compacting" | "compacted" | "initializing" | "initialized"
|
||||
---
|
||||
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
|
||||
---@field acp_client? ACPClient
|
||||
---@field on_save_acp_client? fun(client: ACPClient): nil
|
||||
---@field acp_session_id? string
|
||||
---@field on_save_acp_session_id? fun(session_id: string): nil
|
||||
---@field on_start AvanteLLMStartCallback
|
||||
---@field on_chunk? AvanteLLMChunkCallback
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
|
||||
@@ -1700,4 +1700,12 @@ function M.fix_diff(diff)
|
||||
return table.concat(the_final_diff_lines, "\n")
|
||||
end
|
||||
|
||||
function M.get_unified_diff(text1, text2, opts)
|
||||
opts = opts or {}
|
||||
opts.result_type = "unified"
|
||||
opts.ctxlen = opts.ctxlen or 3
|
||||
|
||||
return vim.diff(text1, text2, opts)
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -26,7 +26,7 @@ function Tokens.calculate_tokens(content)
|
||||
elseif type(item) == "table" and item.type == "image" then
|
||||
text = text .. item.source.data
|
||||
elseif type(item) == "table" and item.type == "tool_result" then
|
||||
text = text .. item.content
|
||||
if type(item.content) == "string" then text = text .. item.content end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user