refactor: history messages (#1934)
This commit is contained in:
25
README.md
25
README.md
@@ -305,6 +305,8 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
||||
{
|
||||
---@alias Provider "claude" | "openai" | "azure" | "gemini" | "cohere" | "copilot" | string
|
||||
provider = "claude", -- The provider used in Aider mode or in the planning phase of Cursor Planning Mode
|
||||
---@alias Mode "agentic" | "legacy"
|
||||
mode = "agentic", -- The default mode for interaction. "agentic" uses tools to automatically generate code, "legacy" uses the old planning method to generate code.
|
||||
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
|
||||
-- currently designating it as `copilot` provider is dangerous because: https://github.com/yetone/avante.nvim/issues/1048
|
||||
-- Of course, you can reduce the request frequency by increasing `suggestion.debounce`.
|
||||
@@ -340,8 +342,6 @@ _See [config.lua#L9](./lua/avante/config.lua) for the full config_
|
||||
support_paste_from_clipboard = false,
|
||||
minimize_diff = true, -- Whether to remove unchanged lines when applying a code block
|
||||
enable_token_counting = true, -- Whether to enable token counting. Default to true.
|
||||
enable_cursor_planning_mode = false, -- Whether to enable Cursor Planning Mode. Default to false.
|
||||
enable_claude_text_editor_tool_mode = false, -- Whether to enable Claude Text Editor Tool Mode.
|
||||
},
|
||||
mappings = {
|
||||
--- @class AvanteConflictMappings
|
||||
@@ -734,12 +734,6 @@ Avante provides a set of default providers, but users can also create their own
|
||||
|
||||
For more information, see [Custom Providers](https://github.com/yetone/avante.nvim/wiki/Custom-providers)
|
||||
|
||||
## Cursor planning mode
|
||||
|
||||
Because avante.nvim has always used Aider’s method for planning applying, but its prompts are very picky with models and require ones like claude-3.5-sonnet or gpt-4o to work properly.
|
||||
|
||||
Therefore, I have adopted Cursor’s method to implement planning applying. For details on the implementation, please refer to [cursor-planning-mode.md](./cursor-planning-mode.md)
|
||||
|
||||
## RAG Service
|
||||
|
||||
Avante provides a RAG service, which is a tool for obtaining the required context for the AI to generate the codes. By default, it is not enabled. You can enable it this way:
|
||||
@@ -890,21 +884,6 @@ Avante allows you to define custom tools that can be used by the AI during code
|
||||
|
||||
Now you can integrate MCP functionality for Avante through `mcphub.nvim`. For detailed documentation, please refer to [mcphub.nvim](https://github.com/ravitemer/mcphub.nvim#avante-integration)
|
||||
|
||||
## Claude Text Editor Tool Mode
|
||||
|
||||
Avante leverages [Claude Text Editor Tool](https://docs.anthropic.com/en/docs/build-with-claude/tool-use/text-editor-tool) to provide a more elegant code editing experience. You can now enable this feature by setting `enable_claude_text_editor_tool_mode` to `true` in the `behaviour` configuration:
|
||||
|
||||
```lua
|
||||
{
|
||||
behaviour = {
|
||||
enable_claude_text_editor_tool_mode = true,
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> To enable **Claude Text Editor Tool Mode**, you must use the `claude-3-5-sonnet-*` or `claude-3-7-sonnet-*` model with the `claude` provider! This feature is not supported by any other models!
|
||||
|
||||
## Custom prompts
|
||||
|
||||
By default, `avante.nvim` provides three different modes to interact with: `planning`, `editing`, and `suggesting`, followed with three different prompts per mode.
|
||||
|
||||
@@ -226,16 +226,16 @@ end
|
||||
function M.select_model() require("avante.model_selector").open() end
|
||||
|
||||
function M.select_history()
|
||||
require("avante.history_selector").open(vim.api.nvim_get_current_buf(), function(filename)
|
||||
local Path = require("avante.path")
|
||||
Path.history.save_latest_filename(vim.api.nvim_get_current_buf(), filename)
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then
|
||||
require("avante.api").ask()
|
||||
sidebar = require("avante").get()
|
||||
end
|
||||
sidebar:update_content_with_history()
|
||||
if not sidebar:is_open() then sidebar:open({}) end
|
||||
local buf = vim.api.nvim_get_current_buf()
|
||||
require("avante.history_selector").open(buf, function(filename)
|
||||
vim.api.nvim_buf_call(buf, function()
|
||||
if not require("avante").is_sidebar_open() then require("avante").open_sidebar({}) end
|
||||
local Path = require("avante.path")
|
||||
Path.history.save_latest_filename(buf, filename)
|
||||
local sidebar = require("avante").get()
|
||||
sidebar:update_content_with_history()
|
||||
vim.schedule(function() sidebar:focus_input() end)
|
||||
end)
|
||||
end)
|
||||
end
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ local M = {}
|
||||
---@class avante.Config
|
||||
M._defaults = {
|
||||
debug = false,
|
||||
---@alias avante.Mode "agentic" | "legacy"
|
||||
mode = "agentic",
|
||||
---@alias avante.ProviderName "claude" | "openai" | "azure" | "gemini" | "vertex" | "cohere" | "copilot" | "bedrock" | "ollama" | string
|
||||
provider = "claude",
|
||||
-- WARNING: Since auto-suggestions are a high-frequency operation and therefore expensive,
|
||||
@@ -380,8 +382,6 @@ M._defaults = {
|
||||
support_paste_from_clipboard = false,
|
||||
minimize_diff = true,
|
||||
enable_token_counting = true,
|
||||
enable_cursor_planning_mode = false,
|
||||
enable_claude_text_editor_tool_mode = false,
|
||||
use_cwd_as_project_root = false,
|
||||
auto_focus_on_diff_view = false,
|
||||
},
|
||||
|
||||
@@ -40,6 +40,12 @@ local Highlights = {
|
||||
AVANTE_SIDEBAR_NORMAL = { name = "AvanteSidebarNormal", link = "NormalFloat" },
|
||||
AVANTE_COMMENT_FG = { name = "AvanteCommentFg", fg_link = "Comment" },
|
||||
AVANTE_REVERSED_NORMAL = { name = "AvanteReversedNormal", fg_link_bg = "Normal", bg_link_fg = "Normal" },
|
||||
AVANTE_STATE_SPINNER_GENERATING = { name = "AvanteStateSpinnerGenerating", fg = "#1e222a", bg = "#ab9df2" },
|
||||
AVANTE_STATE_SPINNER_TOOL_CALLING = { name = "AvanteStateSpinnerToolCalling", fg = "#1e222a", bg = "#56b6c2" },
|
||||
AVANTE_STATE_SPINNER_FAILED = { name = "AvanteStateSpinnerFailed", fg = "#1e222a", bg = "#e06c75" },
|
||||
AVANTE_STATE_SPINNER_SUCCEEDED = { name = "AvanteStateSpinnerSucceeded", fg = "#1e222a", bg = "#98c379" },
|
||||
AVANTE_STATE_SPINNER_SEARCHING = { name = "AvanteStateSpinnerSearching", fg = "#1e222a", bg = "#c678dd" },
|
||||
AVANTE_STATE_SPINNER_THINKING = { name = "AvanteStateSpinnerThinking", fg = "#1e222a", bg = "#c678dd" },
|
||||
}
|
||||
|
||||
Highlights.conflict = {
|
||||
|
||||
28
lua/avante/history_message.lua
Normal file
28
lua/avante/history_message.lua
Normal file
@@ -0,0 +1,28 @@
|
||||
local Utils = require("avante.utils")
|
||||
|
||||
---@class avante.HistoryMessage
|
||||
local M = {}
|
||||
M.__index = M
|
||||
|
||||
---@param message AvanteLLMMessage
|
||||
---@param opts? {is_user_submission?: boolean, visible?: boolean, displayed_content?: string, state?: avante.HistoryMessageState, uuid?: string, selected_filepaths?: string[], selected_code?: AvanteSelectedCode, just_for_display?: boolean}
|
||||
---@return avante.HistoryMessage
|
||||
function M:new(message, opts)
|
||||
opts = opts or {}
|
||||
local obj = setmetatable({}, M)
|
||||
obj.message = message
|
||||
obj.uuid = opts.uuid or Utils.uuid()
|
||||
obj.state = opts.state or "generated"
|
||||
obj.timestamp = Utils.get_timestamp()
|
||||
obj.is_user_submission = false
|
||||
obj.visible = true
|
||||
if opts.is_user_submission ~= nil then obj.is_user_submission = opts.is_user_submission end
|
||||
if opts.visible ~= nil then obj.visible = opts.visible end
|
||||
if opts.displayed_content ~= nil then obj.displayed_content = opts.displayed_content end
|
||||
if opts.selected_filepaths ~= nil then obj.selected_filepaths = opts.selected_filepaths end
|
||||
if opts.selected_code ~= nil then obj.selected_code = opts.selected_code end
|
||||
if opts.just_for_display ~= nil then obj.just_for_display = opts.just_for_display end
|
||||
return obj
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -9,8 +9,9 @@ local M = {}
|
||||
---@param history avante.ChatHistory
|
||||
---@return table?
|
||||
local function to_selector_item(history)
|
||||
local timestamp = #history.entries > 0 and history.entries[#history.entries].timestamp or history.timestamp
|
||||
local name = history.title .. " - " .. timestamp .. " (" .. #history.entries .. ")"
|
||||
local messages = Utils.get_history_messages(history)
|
||||
local timestamp = #messages > 0 and messages[#messages].timestamp or history.timestamp
|
||||
local name = history.title .. " - " .. timestamp .. " (" .. #messages .. ")"
|
||||
name = name:gsub("\n", "\\n")
|
||||
return {
|
||||
name = name,
|
||||
|
||||
@@ -10,6 +10,7 @@ local Path = require("avante.path")
|
||||
local Providers = require("avante.providers")
|
||||
local LLMToolHelpers = require("avante.llm_tools.helpers")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
|
||||
---@class avante.LLM
|
||||
local M = {}
|
||||
@@ -26,7 +27,7 @@ function M.summarize_chat_thread_title(content, cb)
|
||||
local system_prompt =
|
||||
[[Summarize the content as a title for the chat thread. The title should be a concise and informative summary of the conversation, capturing the main points and key takeaways. It should be no longer than 100 words and should be written in a clear and engaging style. The title should be suitable for use as the title of a chat thread on a messaging platform or other communication medium.]]
|
||||
local response_content = ""
|
||||
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||
local provider = Providers.get_memory_summary_provider()
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
@@ -58,73 +59,49 @@ function M.summarize_chat_thread_title(content, cb)
|
||||
})
|
||||
end
|
||||
|
||||
---@param messages AvanteLLMMessage[]
|
||||
---@return AvanteLLMMessage[]
|
||||
local function filter_out_tool_use_messages(messages)
|
||||
local filtered_messages = {}
|
||||
for _, message in ipairs(messages) do
|
||||
local content = message.content
|
||||
if type(content) == "table" then
|
||||
local new_content = {}
|
||||
for _, item in ipairs(content) do
|
||||
if item.type == "tool_use" or item.type == "tool_result" then goto continue end
|
||||
table.insert(new_content, item)
|
||||
::continue::
|
||||
end
|
||||
content = new_content
|
||||
end
|
||||
if type(content) == "table" then
|
||||
if #content > 0 then table.insert(filtered_messages, { role = message.role, content = content }) end
|
||||
else
|
||||
table.insert(filtered_messages, { role = message.role, content = content })
|
||||
end
|
||||
end
|
||||
return filtered_messages
|
||||
end
|
||||
|
||||
---@param bufnr integer
|
||||
---@param history avante.ChatHistory
|
||||
---@param entries? avante.ChatHistoryEntry[]
|
||||
---@param history_messages avante.HistoryMessage[]
|
||||
---@param cb fun(memory: avante.ChatMemory | nil): nil
|
||||
function M.summarize_memory(bufnr, history, entries, cb)
|
||||
local system_prompt = [[You are a helpful AI assistant tasked with summarizing conversations.]]
|
||||
if not entries then entries = Utils.history.filter_active_entries(history.entries) end
|
||||
if #entries == 0 then
|
||||
cb(nil)
|
||||
return
|
||||
end
|
||||
if history.memory then
|
||||
entries = vim
|
||||
.iter(entries)
|
||||
:filter(function(entry) return entry.timestamp > history.memory.last_summarized_timestamp end)
|
||||
:totable()
|
||||
end
|
||||
if #entries == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
local history_messages = Utils.history.entries_to_llm_messages(entries)
|
||||
history_messages = filter_out_tool_use_messages(history_messages)
|
||||
history_messages = vim.list_slice(history_messages, 1, 4)
|
||||
function M.summarize_memory(bufnr, history, history_messages, cb)
|
||||
local system_prompt =
|
||||
[[You are an expert coding assistant. Your goal is to generate a concise, structured summary of the conversation below that captures all essential information needed to continue development after context replacement. Include tasks performed, code areas modified or reviewed, key decisions or assumptions, test results or errors, and outstanding tasks or next steps.]]
|
||||
if #history_messages == 0 then
|
||||
cb(history.memory)
|
||||
return
|
||||
end
|
||||
Utils.debug("summarize memory", #history_messages, history_messages[#history_messages].content)
|
||||
local user_prompt =
|
||||
[[Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.]]
|
||||
local latest_timestamp = history_messages[#history_messages].timestamp
|
||||
local latest_message_uuid = history_messages[#history_messages].uuid
|
||||
local conversation_items = vim
|
||||
.iter(history_messages)
|
||||
:filter(function(msg)
|
||||
if msg.just_for_display then return false end
|
||||
if msg.message.role ~= "assistant" and msg.message.role ~= "user" then return false end
|
||||
local content = msg.message.content
|
||||
if type(content) == "table" and content[1].type == "tool_result" then return false end
|
||||
if type(content) == "table" and content[1].type == "tool_use" then return false end
|
||||
return true
|
||||
end)
|
||||
:map(function(msg) return msg.message.role .. ": " .. Utils.message_to_text(msg) end)
|
||||
:totable()
|
||||
local conversation_text = table.concat(conversation_items, "\n")
|
||||
local user_prompt = "Here is the conversation so far:\n"
|
||||
.. conversation_text
|
||||
.. "\n\nPlease summarize this conversation, covering:\n1. Tasks performed and outcomes\n2. Code files, modules, or functions modified or examined\n3. Important decisions or assumptions made\n4. Errors encountered and test or build results\n5. Remaining tasks, open questions, or next steps\nProvide the summary in a clear, concise format."
|
||||
if history.memory then user_prompt = user_prompt .. "\n\nThe previous summary is:\n\n" .. history.memory.content end
|
||||
table.insert(history_messages, {
|
||||
role = "user",
|
||||
content = user_prompt,
|
||||
})
|
||||
local messages = {
|
||||
{
|
||||
role = "user",
|
||||
content = user_prompt,
|
||||
},
|
||||
}
|
||||
local response_content = ""
|
||||
local provider = Providers[Config.memory_summary_provider or Config.provider]
|
||||
local provider = Providers.get_memory_summary_provider()
|
||||
M.curl({
|
||||
provider = provider,
|
||||
prompt_opts = {
|
||||
system_prompt = system_prompt,
|
||||
messages = history_messages,
|
||||
messages = messages,
|
||||
},
|
||||
handler_opts = {
|
||||
on_start = function(_) end,
|
||||
@@ -141,11 +118,14 @@ function M.summarize_memory(bufnr, history, entries, cb)
|
||||
response_content = Utils.trim_think_content(response_content)
|
||||
local memory = {
|
||||
content = response_content,
|
||||
last_summarized_timestamp = entries[#entries].timestamp,
|
||||
last_summarized_timestamp = latest_timestamp,
|
||||
last_message_uuid = latest_message_uuid,
|
||||
}
|
||||
history.memory = memory
|
||||
Path.history.save(bufnr, history)
|
||||
cb(memory)
|
||||
else
|
||||
cb(history.memory)
|
||||
end
|
||||
end,
|
||||
},
|
||||
@@ -156,7 +136,7 @@ end
|
||||
---@return AvantePromptOptions
|
||||
function M.generate_prompts(opts)
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
local mode = opts.mode or "planning"
|
||||
local mode = opts.mode or Config.mode
|
||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
local _, request_body = Providers.parse_config(provider)
|
||||
local max_tokens = request_body.max_tokens or 4096
|
||||
@@ -166,22 +146,58 @@ function M.generate_prompts(opts)
|
||||
if opts.prompt_opts and opts.prompt_opts.image_paths then
|
||||
image_paths = vim.list_extend(image_paths, opts.prompt_opts.image_paths)
|
||||
end
|
||||
local instructions = opts.instructions
|
||||
if instructions and instructions:match("image: ") then
|
||||
local lines = vim.split(opts.instructions, "\n")
|
||||
for i, line in ipairs(lines) do
|
||||
if line:match("^image: ") then
|
||||
local image_path = line:gsub("^image: ", "")
|
||||
table.insert(image_paths, image_path)
|
||||
table.remove(lines, i)
|
||||
end
|
||||
end
|
||||
instructions = table.concat(lines, "\n")
|
||||
end
|
||||
|
||||
local project_root = Utils.root.get()
|
||||
Path.prompts.initialize(Path.prompts.get_templates_dir(project_root))
|
||||
|
||||
local tool_id_to_tool_name = {}
|
||||
local tool_id_to_path = {}
|
||||
local viewed_files = {}
|
||||
if opts.history_messages then
|
||||
for _, message in ipairs(opts.history_messages) do
|
||||
local content = message.message.content
|
||||
if type(content) ~= "table" then goto continue end
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) ~= "table" then goto continue1 end
|
||||
if item.type ~= "tool_use" then goto continue1 end
|
||||
local tool_name = item.name
|
||||
if tool_name ~= "view" then goto continue1 end
|
||||
local path = item.input.path
|
||||
tool_id_to_tool_name[item.id] = tool_name
|
||||
if path then
|
||||
local uniform_path = Utils.uniform_path(path)
|
||||
tool_id_to_path[item.id] = uniform_path
|
||||
viewed_files[uniform_path] = item.id
|
||||
end
|
||||
::continue1::
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
for _, message in ipairs(opts.history_messages) do
|
||||
local content = message.message.content
|
||||
if type(content) == "table" then
|
||||
for _, item in ipairs(content) do
|
||||
if type(item) ~= "table" then goto continue end
|
||||
if item.type ~= "tool_result" then goto continue end
|
||||
local tool_name = tool_id_to_tool_name[item.tool_use_id]
|
||||
if tool_name ~= "view" then goto continue end
|
||||
local path = tool_id_to_path[item.tool_use_id]
|
||||
local latest_tool_id = viewed_files[path]
|
||||
if not latest_tool_id then goto continue end
|
||||
if latest_tool_id ~= item.tool_use_id then
|
||||
item.content = string.format("The file %s has been updated. Please use the latest view tool result!", path)
|
||||
else
|
||||
local lines, error = Utils.read_file_from_buf_or_disk(path)
|
||||
if error ~= nil then Utils.error("error reading file: " .. error) end
|
||||
lines = lines or {}
|
||||
item.content = table.concat(lines, "\n")
|
||||
end
|
||||
::continue::
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local system_info = Utils.get_system_info()
|
||||
|
||||
local selected_files = opts.selected_files or {}
|
||||
@@ -200,6 +216,8 @@ function M.generate_prompts(opts)
|
||||
end
|
||||
end
|
||||
|
||||
selected_files = vim.iter(selected_files):filter(function(file) return viewed_files[file.path] == nil end):totable()
|
||||
|
||||
local template_opts = {
|
||||
ask = opts.ask, -- TODO: add mode without ask instruction
|
||||
code_lang = opts.code_lang,
|
||||
@@ -229,36 +247,42 @@ function M.generate_prompts(opts)
|
||||
end
|
||||
|
||||
---@type AvanteLLMMessage[]
|
||||
local messages = {}
|
||||
local context_messages = {}
|
||||
if opts.prompt_opts and opts.prompt_opts.messages then
|
||||
messages = vim.list_extend(messages, opts.prompt_opts.messages)
|
||||
context_messages = vim.list_extend(context_messages, opts.prompt_opts.messages)
|
||||
end
|
||||
|
||||
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
|
||||
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
|
||||
if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
|
||||
if project_context ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = project_context, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
|
||||
local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
|
||||
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
|
||||
if diagnostics ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = diagnostics, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if #selected_files > 0 or opts.selected_code ~= nil then
|
||||
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
|
||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||
if code_context ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = code_context, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if opts.memory ~= nil and opts.memory ~= "" and opts.memory ~= "null" then
|
||||
local memory = Path.prompts.render_file("_memory.avanterules", template_opts)
|
||||
if memory ~= "" then table.insert(messages, { role = "user", content = memory }) end
|
||||
if memory ~= "" then
|
||||
table.insert(context_messages, { role = "user", content = memory, visible = false, is_context = true })
|
||||
end
|
||||
end
|
||||
|
||||
if instructions then table.insert(messages, { role = "user", content = instructions }) end
|
||||
|
||||
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
|
||||
|
||||
for _, message in ipairs(messages) do
|
||||
for _, message in ipairs(context_messages) do
|
||||
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
|
||||
end
|
||||
|
||||
@@ -267,47 +291,49 @@ function M.generate_prompts(opts)
|
||||
dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages)
|
||||
end
|
||||
|
||||
local final_history_messages = {}
|
||||
if opts.history_messages then
|
||||
if Config.history.max_tokens > 0 then remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens) end
|
||||
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
|
||||
local history_messages = {}
|
||||
for i = #opts.history_messages, 1, -1 do
|
||||
local message = opts.history_messages[i]
|
||||
if Config.history.carried_entry_count ~= nil then
|
||||
if #history_messages > Config.history.carried_entry_count then break end
|
||||
table.insert(history_messages, message)
|
||||
local tokens = Utils.tokens.calculate_tokens(message.message.content)
|
||||
remaining_tokens = remaining_tokens - tokens
|
||||
if remaining_tokens > 0 then
|
||||
table.insert(history_messages, 1, message)
|
||||
else
|
||||
local tokens = Utils.tokens.calculate_tokens(message.content)
|
||||
remaining_tokens = remaining_tokens - tokens
|
||||
if remaining_tokens > 0 then
|
||||
table.insert(history_messages, message)
|
||||
else
|
||||
break
|
||||
end
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if #history_messages == 0 then
|
||||
history_messages = vim.list_slice(opts.history_messages, #opts.history_messages - 1, #opts.history_messages)
|
||||
end
|
||||
|
||||
dropped_history_messages = vim.list_slice(opts.history_messages, 1, #opts.history_messages - #history_messages)
|
||||
|
||||
-- prepend the history messages to the messages table
|
||||
vim.iter(history_messages):each(function(msg) table.insert(messages, 1, msg) end)
|
||||
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
|
||||
vim.iter(history_messages):each(function(msg) table.insert(final_history_messages, msg) end)
|
||||
end
|
||||
|
||||
if opts.mode == "cursor-applying" then
|
||||
local user_prompt = [[
|
||||
Merge all changes from the <update> snippet into the <code> below.
|
||||
- Preserve the code's structure, order, comments, and indentation exactly.
|
||||
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
|
||||
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
|
||||
-- Utils.debug("opts.history_messages", opts.history_messages)
|
||||
-- Utils.debug("final_history_messages", final_history_messages)
|
||||
|
||||
]]
|
||||
user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code)
|
||||
for _, snippet in ipairs(opts.update_snippets) do
|
||||
user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet)
|
||||
end
|
||||
user_prompt = user_prompt .. "Provide the complete updated code."
|
||||
table.insert(messages, { role = "user", content = user_prompt })
|
||||
---@type AvanteLLMMessage[]
|
||||
local messages = vim.deepcopy(context_messages)
|
||||
for _, msg in ipairs(final_history_messages) do
|
||||
local message = msg.message
|
||||
table.insert(messages, message)
|
||||
end
|
||||
|
||||
messages = vim
|
||||
.iter(messages)
|
||||
:filter(function(msg) return type(msg.content) ~= "string" or msg.content ~= "" end)
|
||||
:totable()
|
||||
|
||||
if opts.instructions ~= nil and opts.instructions ~= "" then
|
||||
messages = vim.list_extend(messages, { { role = "user", content = opts.instructions } })
|
||||
end
|
||||
|
||||
opts.session_ctx = opts.session_ctx or {}
|
||||
@@ -318,19 +344,12 @@ Merge all changes from the <update> snippet into the <code> below.
|
||||
if opts.tools then tools = vim.list_extend(tools, opts.tools) end
|
||||
if opts.prompt_opts and opts.prompt_opts.tools then tools = vim.list_extend(tools, opts.prompt_opts.tools) end
|
||||
|
||||
local tool_histories = {}
|
||||
if opts.tool_histories then tool_histories = vim.list_extend(tool_histories, opts.tool_histories) end
|
||||
if opts.prompt_opts and opts.prompt_opts.tool_histories then
|
||||
tool_histories = vim.list_extend(tool_histories, opts.prompt_opts.tool_histories)
|
||||
end
|
||||
|
||||
---@type AvantePromptOptions
|
||||
return {
|
||||
system_prompt = system_prompt,
|
||||
messages = messages,
|
||||
image_paths = image_paths,
|
||||
tools = tools,
|
||||
tool_histories = tool_histories,
|
||||
dropped_history_messages = dropped_history_messages,
|
||||
}
|
||||
end
|
||||
@@ -372,7 +391,9 @@ function M.curl(opts)
|
||||
---@type string
|
||||
local current_event_state = nil
|
||||
local resp_ctx = {}
|
||||
resp_ctx.session_id = Utils.uuid()
|
||||
|
||||
local response_body = ""
|
||||
---@param line string
|
||||
local function parse_stream_data(line)
|
||||
local event = line:match("^event:%s*(.+)$")
|
||||
@@ -381,7 +402,16 @@ function M.curl(opts)
|
||||
return
|
||||
end
|
||||
local data_match = line:match("^data:%s*(.+)$")
|
||||
if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
|
||||
if data_match then
|
||||
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
|
||||
else
|
||||
response_body = response_body .. line
|
||||
local ok, jsn = pcall(vim.json.decode, response_body)
|
||||
if ok then
|
||||
response_body = ""
|
||||
if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function parse_response_without_stream(data)
|
||||
@@ -394,10 +424,13 @@ function M.curl(opts)
|
||||
|
||||
local temp_file = fn.tempname()
|
||||
local curl_body_file = temp_file .. "-request-body.json"
|
||||
local resp_body_file = temp_file .. "-response-body.json"
|
||||
local json_content = vim.json.encode(spec.body)
|
||||
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
||||
|
||||
Utils.debug("curl body file:", curl_body_file)
|
||||
Utils.debug("curl request body file:", curl_body_file)
|
||||
|
||||
Utils.debug("curl response body file:", resp_body_file)
|
||||
|
||||
local headers_file = temp_file .. "-headers.txt"
|
||||
|
||||
@@ -407,6 +440,7 @@ function M.curl(opts)
|
||||
if Config.debug then return end
|
||||
vim.schedule(function()
|
||||
fn.delete(curl_body_file)
|
||||
pcall(fn.delete, resp_body_file)
|
||||
fn.delete(headers_file)
|
||||
end)
|
||||
end
|
||||
@@ -431,6 +465,15 @@ function M.curl(opts)
|
||||
return
|
||||
end
|
||||
if not data then return end
|
||||
if Config.debug then
|
||||
if type(data) == "string" then
|
||||
local file = io.open(resp_body_file, "a")
|
||||
if file then
|
||||
file:write(data .. "\n")
|
||||
file:close()
|
||||
end
|
||||
end
|
||||
end
|
||||
vim.schedule(function()
|
||||
if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
|
||||
if provider.parse_response ~= nil then
|
||||
@@ -495,6 +538,17 @@ function M.curl(opts)
|
||||
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
|
||||
end
|
||||
if result.status == 429 then
|
||||
Utils.debug("result", result)
|
||||
if result.body then
|
||||
local ok, jsn = pcall(vim.json.decode, result.body)
|
||||
if ok then
|
||||
if jsn.error and jsn.error.message then
|
||||
handler_opts.on_stop({ reason = "error", error = jsn.error.message })
|
||||
return
|
||||
end
|
||||
end
|
||||
end
|
||||
Utils.debug("result", result)
|
||||
local retry_after = 10
|
||||
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
|
||||
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
|
||||
@@ -585,17 +639,34 @@ function M._stream(opts)
|
||||
|
||||
---@type AvanteHandlerOptions
|
||||
local handler_opts = {
|
||||
on_partial_tool_use = opts.on_partial_tool_use,
|
||||
on_messages_add = opts.on_messages_add,
|
||||
on_state_change = opts.on_state_change,
|
||||
on_start = opts.on_start,
|
||||
on_chunk = opts.on_chunk,
|
||||
on_stop = function(stop_opts)
|
||||
---@param tool_use_list AvanteLLMToolUse[]
|
||||
---@param tool_use_index integer
|
||||
---@param tool_histories AvanteLLMToolHistory[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_histories)
|
||||
---@param tool_results AvanteLLMToolResult[]
|
||||
local function handle_next_tool_use(tool_use_list, tool_use_index, tool_results)
|
||||
if tool_use_index > #tool_use_list then
|
||||
---@type avante.HistoryMessage[]
|
||||
local messages = {}
|
||||
for _, tool_result in ipairs(tool_results) do
|
||||
messages[#messages + 1] = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = tool_result.tool_use_id,
|
||||
content = tool_result.content,
|
||||
is_error = tool_result.is_error,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
opts.on_messages_add(messages)
|
||||
local new_opts = vim.tbl_deep_extend("force", opts, {
|
||||
tool_histories = tool_histories,
|
||||
history_messages = opts.get_history_messages(),
|
||||
})
|
||||
if provider.get_rate_limit_sleep_time then
|
||||
local sleep_time = provider:get_rate_limit_sleep_time(resp_headers)
|
||||
@@ -616,7 +687,7 @@ function M._stream(opts)
|
||||
if error == LLMToolHelpers.CANCEL_TOKEN then
|
||||
Utils.debug("Tool execution was cancelled by user")
|
||||
opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n")
|
||||
return opts.on_stop({ reason = "cancelled", tool_histories = tool_histories })
|
||||
return opts.on_stop({ reason = "cancelled" })
|
||||
end
|
||||
|
||||
local tool_result = {
|
||||
@@ -624,8 +695,8 @@ function M._stream(opts)
|
||||
content = error ~= nil and error or result,
|
||||
is_error = error ~= nil,
|
||||
}
|
||||
table.insert(tool_histories, { tool_result = tool_result, tool_use = tool_use })
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_histories)
|
||||
table.insert(tool_results, tool_result)
|
||||
return handle_next_tool_use(tool_use_list, tool_use_index + 1, tool_results)
|
||||
end
|
||||
-- Either on_complete handles the tool result asynchronously or we receive the result and error synchronously when either is not nil
|
||||
local result, error = LLMTools.process_tool_use(
|
||||
@@ -638,20 +709,53 @@ function M._stream(opts)
|
||||
if result ~= nil or error ~= nil then return handle_tool_result(result, error) end
|
||||
end
|
||||
if stop_opts.reason == "cancelled" then
|
||||
opts.on_chunk("\n*[Request cancelled by user.]*\n")
|
||||
return opts.on_stop({ reason = "cancelled", tool_histories = opts.tool_histories })
|
||||
if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end
|
||||
if opts.on_messages_add then
|
||||
local message = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = "[Request cancelled by user.]",
|
||||
})
|
||||
opts.on_messages_add({ message })
|
||||
end
|
||||
return opts.on_stop({ reason = "cancelled" })
|
||||
end
|
||||
if stop_opts.reason == "tool_use" and stop_opts.tool_use_list then
|
||||
local old_tool_histories = vim.deepcopy(opts.tool_histories) or {}
|
||||
if stop_opts.reason == "tool_use" then
|
||||
local tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
local tool_result_seen = {}
|
||||
local history_messages = opts.get_history_messages and opts.get_history_messages() or {}
|
||||
for idx = #history_messages, 1, -1 do
|
||||
local message = history_messages[idx]
|
||||
local content = message.message.content
|
||||
if type(content) ~= "table" or #content == 0 then goto continue end
|
||||
if content[1].type == "tool_use" then
|
||||
if not tool_result_seen[content[1].id] then
|
||||
table.insert(tool_use_list, 1, content[1])
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
if content[1].type == "tool_result" then tool_result_seen[content[1].tool_use_id] = true end
|
||||
::continue::
|
||||
end
|
||||
local sorted_tool_use_list = {} ---@type AvanteLLMToolUse[]
|
||||
for _, tool_use in vim.spairs(stop_opts.tool_use_list) do
|
||||
for _, tool_use in vim.spairs(tool_use_list) do
|
||||
table.insert(sorted_tool_use_list, tool_use)
|
||||
end
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, old_tool_histories)
|
||||
return handle_next_tool_use(sorted_tool_use_list, 1, {})
|
||||
end
|
||||
if stop_opts.reason == "rate_limit" then
|
||||
local msg = "Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ..."
|
||||
opts.on_chunk("\n*[" .. msg .. "]*\n")
|
||||
local msg_content = "*[Rate limit reached. Retrying in " .. stop_opts.retry_after .. " seconds ...]*"
|
||||
if opts.on_chunk then opts.on_chunk("\n" .. msg_content .. "\n") end
|
||||
local message
|
||||
if opts.on_messages_add then
|
||||
message = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = "\n\n" .. msg_content,
|
||||
}, {
|
||||
just_for_display = true,
|
||||
})
|
||||
opts.on_messages_add({ message })
|
||||
end
|
||||
local timer = vim.loop.new_timer()
|
||||
if timer then
|
||||
local retry_after = stop_opts.retry_after
|
||||
@@ -661,8 +765,12 @@ function M._stream(opts)
|
||||
0,
|
||||
vim.schedule_wrap(function()
|
||||
if retry_after > 0 then retry_after = retry_after - 1 end
|
||||
local msg_ = "Rate limit reached. Retrying in " .. retry_after .. " seconds ..."
|
||||
opts.on_chunk([[\033[1A\033[K]] .. "\n*[" .. msg_ .. "]*\n")
|
||||
local msg_content_ = "*[Rate limit reached. Retrying in " .. retry_after .. " seconds ...]*"
|
||||
if opts.on_chunk then opts.on_chunk([[\033[1A\033[K]] .. "\n" .. msg_content_ .. "\n") end
|
||||
if opts.on_messages_add and message then
|
||||
message.message.content = "\n\n" .. msg_content_
|
||||
opts.on_messages_add({ message })
|
||||
end
|
||||
countdown()
|
||||
end)
|
||||
)
|
||||
@@ -676,7 +784,6 @@ function M._stream(opts)
|
||||
end, stop_opts.retry_after * 1000)
|
||||
return
|
||||
end
|
||||
stop_opts.tool_histories = opts.tool_histories
|
||||
return opts.on_stop(stop_opts)
|
||||
end,
|
||||
}
|
||||
@@ -697,6 +804,8 @@ local function _merge_response(first_response, second_response, opts)
|
||||
|
||||
prompt = prompt .. "\n"
|
||||
|
||||
if opts.instructions == nil then opts.instructions = "" end
|
||||
|
||||
-- append this reference prompt to the prompt_opts messages at last
|
||||
opts.instructions = opts.instructions .. prompt
|
||||
|
||||
@@ -802,20 +911,12 @@ function M.stream(opts)
|
||||
return original_on_stop(stop_opts)
|
||||
end)
|
||||
end
|
||||
if opts.on_partial_tool_use ~= nil then
|
||||
local original_on_partial_tool_use = opts.on_partial_tool_use
|
||||
opts.on_partial_tool_use = vim.schedule_wrap(function(tool_use)
|
||||
if is_completed then return end
|
||||
return original_on_partial_tool_use(tool_use)
|
||||
end)
|
||||
end
|
||||
|
||||
local valid_dual_boost_modes = {
|
||||
planning = true,
|
||||
["cursor-planning"] = true,
|
||||
legacy = true,
|
||||
}
|
||||
|
||||
opts.mode = opts.mode or "planning"
|
||||
opts.mode = opts.mode or Config.mode
|
||||
|
||||
if Config.dual_boost.enabled and valid_dual_boost_modes[opts.mode] then
|
||||
M._dual_boost_stream(
|
||||
|
||||
@@ -10,7 +10,7 @@ M.name = "create"
|
||||
|
||||
M.description = "The create tool allows you to create a new file with specified content."
|
||||
|
||||
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end
|
||||
function M.enabled() return require("avante.config").mode == "agentic" end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
|
||||
@@ -71,7 +71,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
local prompt = opts.prompt
|
||||
local tools = get_available_tools()
|
||||
local start_time = os.date("%Y-%m-%d %H:%M:%S")
|
||||
local start_time = Utils.get_timestamp()
|
||||
|
||||
if on_log then on_log("prompt: " .. prompt) end
|
||||
|
||||
@@ -84,6 +84,8 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
messages = messages or {}
|
||||
table.insert(messages, { role = "user", content = prompt })
|
||||
|
||||
local tool_use_messages = {}
|
||||
|
||||
local total_tokens = 0
|
||||
local final_response = ""
|
||||
Llm._stream({
|
||||
@@ -93,6 +95,15 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
on_tool_log = function(tool_id, tool_name, log, state)
|
||||
if on_log then on_log(string.format("[%s] %s", tool_name, log)) end
|
||||
end,
|
||||
on_messages_add = function(msgs)
|
||||
msgs = vim.is_list(msgs) and msgs or { msgs }
|
||||
for _, msg in ipairs(msgs) do
|
||||
local content = msg.message.content
|
||||
if type(content) == "table" and #content > 0 and content[1].type == "tool_use" then
|
||||
tool_use_messages[msg.uuid] = true
|
||||
end
|
||||
end
|
||||
end,
|
||||
session_ctx = session_ctx,
|
||||
prompt_opts = {
|
||||
system_prompt = system_prompt,
|
||||
@@ -111,9 +122,9 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
on_complete(err, nil)
|
||||
return
|
||||
end
|
||||
local end_time = os.date("%Y-%m-%d %H:%M:%S")
|
||||
local elapsed_time = Utils.datetime_diff(tostring(start_time), tostring(end_time))
|
||||
local tool_use_count = stop_opts.tool_histories and #stop_opts.tool_histories or 0
|
||||
local end_time = Utils.get_timestamp()
|
||||
local elapsed_time = Utils.datetime_diff(start_time, end_time)
|
||||
local tool_use_count = vim.tbl_count(tool_use_messages)
|
||||
local summary = "Done ("
|
||||
.. (tool_use_count <= 1 and "1 tool use" or tool_use_count .. " tool uses")
|
||||
.. " · "
|
||||
|
||||
@@ -598,6 +598,7 @@ end
|
||||
|
||||
---@type AvanteLLMTool[]
|
||||
M._tools = {
|
||||
require("avante.llm_tools.replace_in_file"),
|
||||
require("avante.llm_tools.dispatch_agent"),
|
||||
require("avante.llm_tools.glob"),
|
||||
{
|
||||
@@ -1104,7 +1105,7 @@ M._tools = {
|
||||
---@return string | nil result
|
||||
---@return string | nil error
|
||||
function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
|
||||
Utils.debug("use tool", tool_use.name, tool_use.input_json)
|
||||
-- Utils.debug("use tool", tool_use.name, tool_use.input_json)
|
||||
|
||||
-- Check if execution is already cancelled
|
||||
if Helpers.is_cancelled then
|
||||
@@ -1125,8 +1126,7 @@ function M.process_tool_use(tools, tool_use, on_log, on_complete, session_ctx)
|
||||
if tool == nil then return nil, "This tool is not provided: " .. tool_use.name end
|
||||
func = tool.func or M[tool.name]
|
||||
end
|
||||
local ok, input_json = pcall(vim.json.decode, tool_use.input_json)
|
||||
if not ok then return nil, "Failed to decode tool input json: " .. vim.inspect(input_json) end
|
||||
local input_json = tool_use.input
|
||||
if not func then return nil, "Tool not found: " .. tool_use.name end
|
||||
if on_log then on_log(tool_use.id, tool_use.name, "running tool", "running") end
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ M.name = "insert"
|
||||
|
||||
M.description = "The insert tool allows you to insert text at a specific location in a file."
|
||||
|
||||
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end
|
||||
function M.enabled() return require("avante.config").mode == "agentic" end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
|
||||
481
lua/avante/llm_tools/replace_in_file.lua
Normal file
481
lua/avante/llm_tools/replace_in_file.lua
Normal file
@@ -0,0 +1,481 @@
|
||||
local Base = require("avante.llm_tools.base")
|
||||
local Helpers = require("avante.llm_tools.helpers")
|
||||
local Utils = require("avante.utils")
|
||||
local Highlights = require("avante.highlights")
|
||||
local Config = require("avante.config")
|
||||
|
||||
local PRIORITY = (vim.hl or vim.highlight).priorities.user
|
||||
local NAMESPACE = vim.api.nvim_create_namespace("avante-diff")
|
||||
local KEYBINDING_NAMESPACE = vim.api.nvim_create_namespace("avante-diff-keybinding")
|
||||
|
||||
---@class AvanteLLMTool
|
||||
local M = setmetatable({}, Base)
|
||||
|
||||
M.name = "replace_in_file"
|
||||
|
||||
M.description =
|
||||
"Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file."
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
type = "table",
|
||||
fields = {
|
||||
{
|
||||
name = "path",
|
||||
description = "The path to the file in the current project scope",
|
||||
type = "string",
|
||||
},
|
||||
{
|
||||
name = "diff",
|
||||
description = [[
|
||||
One or more SEARCH/REPLACE blocks following this exact format:
|
||||
\`\`\`
|
||||
<<<<<<< SEARCH
|
||||
[exact content to find]
|
||||
=======
|
||||
[new content to replace with]
|
||||
>>>>>>> REPLACE
|
||||
\`\`\`
|
||||
Critical rules:
|
||||
1. SEARCH content must match the associated file section to find EXACTLY:
|
||||
* Match character-for-character including whitespace, indentation, line endings
|
||||
* Include all comments, docstrings, etc.
|
||||
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
|
||||
* Including multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
|
||||
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
|
||||
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
|
||||
3. Keep SEARCH/REPLACE blocks concise:
|
||||
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
|
||||
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
|
||||
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
|
||||
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
|
||||
4. Special operations:
|
||||
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
|
||||
* To delete code: Use empty REPLACE section
|
||||
]],
|
||||
type = "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolReturn[]
|
||||
M.returns = {
|
||||
{
|
||||
name = "success",
|
||||
description = "True if the replacement was successful, false otherwise",
|
||||
type = "boolean",
|
||||
},
|
||||
{
|
||||
name = "error",
|
||||
description = "Error message if the replacement failed",
|
||||
type = "string",
|
||||
optional = true,
|
||||
},
|
||||
}
|
||||
|
||||
---@type AvanteLLMToolFunc<{ path: string, diff: string }>
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not opts.path or not opts.diff then return false, "path and diff are required" end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
|
||||
local diff_lines = vim.split(opts.diff, "\n")
|
||||
local is_searching = false
|
||||
local is_replacing = false
|
||||
local current_search = {}
|
||||
local current_replace = {}
|
||||
local rough_diff_blocks = {}
|
||||
|
||||
for _, line in ipairs(diff_lines) do
|
||||
if line:match("^%s*<<<<<<< SEARCH") then
|
||||
is_searching = true
|
||||
is_replacing = false
|
||||
current_search = {}
|
||||
elseif line:match("^%s*=======") and is_searching then
|
||||
is_searching = false
|
||||
is_replacing = true
|
||||
current_replace = {}
|
||||
elseif line:match("^%s*>>>>>>> REPLACE") and is_replacing then
|
||||
is_replacing = false
|
||||
table.insert(
|
||||
rough_diff_blocks,
|
||||
{ search = table.concat(current_search, "\n"), replace = table.concat(current_replace, "\n") }
|
||||
)
|
||||
elseif is_searching then
|
||||
table.insert(current_search, line)
|
||||
elseif is_replacing then
|
||||
table.insert(current_replace, line)
|
||||
end
|
||||
end
|
||||
|
||||
if #rough_diff_blocks == 0 then return false, "No diff blocks found" end
|
||||
|
||||
local bufnr, err = Helpers.get_bufnr(abs_path)
|
||||
if err then return false, err end
|
||||
local original_lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then return false, "Avante sidebar not found" end
|
||||
|
||||
local function parse_rough_diff_block(rough_diff_block, current_lines)
|
||||
local old_lines = vim.split(rough_diff_block.search, "\n")
|
||||
local new_lines = vim.split(rough_diff_block.replace, "\n")
|
||||
local start_line, end_line
|
||||
for i = 1, #current_lines - #old_lines + 1 do
|
||||
local match = true
|
||||
for j = 1, #old_lines do
|
||||
if Utils.remove_indentation(current_lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
end
|
||||
if match then
|
||||
start_line = i
|
||||
end_line = i + #old_lines - 1
|
||||
break
|
||||
end
|
||||
end
|
||||
if start_line == nil or end_line == nil then
|
||||
return "Failed to find the old string:\n" .. rough_diff_block.search
|
||||
end
|
||||
local old_str = rough_diff_block.search
|
||||
local new_str = rough_diff_block.replace
|
||||
local original_indentation = Utils.get_indentation(current_lines[start_line])
|
||||
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
|
||||
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
|
||||
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
|
||||
old_str = table.concat(old_lines, "\n")
|
||||
new_str = table.concat(new_lines, "\n")
|
||||
end
|
||||
rough_diff_block.old_lines = old_lines
|
||||
rough_diff_block.new_lines = new_lines
|
||||
rough_diff_block.search = old_str
|
||||
rough_diff_block.replace = new_str
|
||||
rough_diff_block.start_line = start_line
|
||||
rough_diff_block.end_line = end_line
|
||||
return nil
|
||||
end
|
||||
|
||||
local function rough_diff_blocks_to_diff_blocks(rough_diff_blocks_)
|
||||
local res = {}
|
||||
local base_line_ = 0
|
||||
for _, rough_diff_block in ipairs(rough_diff_blocks_) do
|
||||
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
|
||||
local patch = vim.diff(rough_diff_block.search, rough_diff_block.replace, { ---@type integer[][]
|
||||
algorithm = "histogram",
|
||||
result_type = "indices",
|
||||
ctxlen = vim.o.scrolloff,
|
||||
})
|
||||
for _, hunk in ipairs(patch) do
|
||||
local start_a, count_a, start_b, count_b = unpack(hunk)
|
||||
local diff_block = {}
|
||||
if count_a > 0 then
|
||||
diff_block.old_lines = vim.list_slice(rough_diff_block.old_lines, start_a, start_a + count_a - 1)
|
||||
else
|
||||
diff_block.old_lines = {}
|
||||
end
|
||||
if count_b > 0 then
|
||||
diff_block.new_lines = vim.list_slice(rough_diff_block.new_lines, start_b, start_b + count_b - 1)
|
||||
else
|
||||
diff_block.new_lines = {}
|
||||
end
|
||||
if count_a > 0 then
|
||||
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a - 1
|
||||
else
|
||||
diff_block.start_line = base_line_ + rough_diff_block.start_line + start_a
|
||||
end
|
||||
diff_block.end_line = base_line_ + rough_diff_block.start_line + start_a + math.max(count_a, 1) - 2
|
||||
diff_block.search = table.concat(diff_block.old_lines, "\n")
|
||||
diff_block.replace = table.concat(diff_block.new_lines, "\n")
|
||||
table.insert(res, diff_block)
|
||||
end
|
||||
|
||||
local distance = 0
|
||||
for _, hunk in ipairs(patch) do
|
||||
local _, count_a, _, count_b = unpack(hunk)
|
||||
distance = distance + count_b - count_a
|
||||
end
|
||||
|
||||
local old_distance = #rough_diff_block.new_lines - #rough_diff_block.old_lines
|
||||
|
||||
base_line_ = base_line_ + distance - old_distance
|
||||
end
|
||||
return res
|
||||
end
|
||||
|
||||
for _, rough_diff_block in ipairs(rough_diff_blocks) do
|
||||
local error = parse_rough_diff_block(rough_diff_block, original_lines)
|
||||
if error then
|
||||
on_complete(false, error)
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
local diff_blocks = rough_diff_blocks_to_diff_blocks(rough_diff_blocks)
|
||||
|
||||
table.sort(diff_blocks, function(a, b) return a.start_line < b.start_line end)
|
||||
|
||||
local base_line = 0
|
||||
for _, diff_block in ipairs(diff_blocks) do
|
||||
diff_block.new_start_line = diff_block.start_line + base_line
|
||||
diff_block.new_end_line = diff_block.new_start_line + #diff_block.new_lines - 1
|
||||
base_line = base_line + #diff_block.new_lines - #diff_block.old_lines
|
||||
end
|
||||
|
||||
local function remove_diff_block(removed_idx, use_new_lines)
|
||||
local new_diff_blocks = {}
|
||||
local distance = 0
|
||||
for idx, diff_block in ipairs(diff_blocks) do
|
||||
if idx == removed_idx then
|
||||
if not use_new_lines then distance = #diff_block.old_lines - #diff_block.new_lines end
|
||||
goto continue
|
||||
end
|
||||
if idx > removed_idx then
|
||||
diff_block.new_start_line = diff_block.new_start_line + distance
|
||||
diff_block.new_end_line = diff_block.new_end_line + distance
|
||||
end
|
||||
table.insert(new_diff_blocks, diff_block)
|
||||
::continue::
|
||||
end
|
||||
|
||||
diff_blocks = new_diff_blocks
|
||||
end
|
||||
|
||||
local function get_current_diff_block()
|
||||
local winid = Utils.get_winid(bufnr)
|
||||
local cursor_line = Utils.get_cursor_pos(winid)
|
||||
for idx, diff_block in ipairs(diff_blocks) do
|
||||
if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then
|
||||
return diff_block, idx
|
||||
end
|
||||
end
|
||||
return nil, nil
|
||||
end
|
||||
|
||||
local function get_prev_diff_block()
|
||||
local winid = Utils.get_winid(bufnr)
|
||||
local cursor_line = Utils.get_cursor_pos(winid)
|
||||
local distance = nil
|
||||
local idx = nil
|
||||
for i, diff_block in ipairs(diff_blocks) do
|
||||
if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then
|
||||
local new_i = i - 1
|
||||
if new_i < 1 then return diff_blocks[#diff_blocks] end
|
||||
return diff_blocks[new_i]
|
||||
end
|
||||
if diff_block.new_start_line < cursor_line then
|
||||
local distance_ = cursor_line - diff_block.new_start_line
|
||||
if distance == nil or distance_ < distance then
|
||||
distance = distance_
|
||||
idx = i
|
||||
end
|
||||
end
|
||||
end
|
||||
if idx ~= nil then return diff_blocks[idx] end
|
||||
if #diff_blocks > 0 then return diff_blocks[#diff_blocks] end
|
||||
return nil
|
||||
end
|
||||
|
||||
local function get_next_diff_block()
|
||||
local winid = Utils.get_winid(bufnr)
|
||||
local cursor_line = Utils.get_cursor_pos(winid)
|
||||
local distance = nil
|
||||
local idx = nil
|
||||
for i, diff_block in ipairs(diff_blocks) do
|
||||
if cursor_line >= diff_block.new_start_line and cursor_line <= diff_block.new_end_line then
|
||||
local new_i = i + 1
|
||||
if new_i > #diff_blocks then return diff_blocks[1] end
|
||||
return diff_blocks[new_i]
|
||||
end
|
||||
if diff_block.new_start_line > cursor_line then
|
||||
local distance_ = diff_block.new_start_line - cursor_line
|
||||
if distance == nil or distance_ < distance then
|
||||
distance = distance_
|
||||
idx = i
|
||||
end
|
||||
end
|
||||
end
|
||||
if idx ~= nil then return diff_blocks[idx] end
|
||||
if #diff_blocks > 0 then return diff_blocks[1] end
|
||||
return nil
|
||||
end
|
||||
|
||||
local show_keybinding_hint_extmark_id = nil
|
||||
local function register_cursor_move_events()
|
||||
local function show_keybinding_hint(lnum)
|
||||
if show_keybinding_hint_extmark_id then
|
||||
vim.api.nvim_buf_del_extmark(bufnr, KEYBINDING_NAMESPACE, show_keybinding_hint_extmark_id)
|
||||
end
|
||||
|
||||
local hint = string.format(
|
||||
"[<%s>: OURS, <%s>: THEIRS, <%s>: PREV, <%s>: NEXT]",
|
||||
Config.mappings.diff.ours,
|
||||
Config.mappings.diff.theirs,
|
||||
Config.mappings.diff.prev,
|
||||
Config.mappings.diff.next
|
||||
)
|
||||
|
||||
show_keybinding_hint_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, KEYBINDING_NAMESPACE, lnum - 1, -1, {
|
||||
hl_group = "AvanteInlineHint",
|
||||
virt_text = { { hint, "AvanteInlineHint" } },
|
||||
virt_text_pos = "right_align",
|
||||
priority = PRIORITY,
|
||||
})
|
||||
end
|
||||
|
||||
vim.api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI", "WinLeave" }, {
|
||||
buffer = bufnr,
|
||||
callback = function(event)
|
||||
local diff_block = get_current_diff_block()
|
||||
if (event.event == "CursorMoved" or event.event == "CursorMovedI") and diff_block then
|
||||
show_keybinding_hint(diff_block.new_start_line)
|
||||
else
|
||||
vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1)
|
||||
end
|
||||
end,
|
||||
})
|
||||
end
|
||||
|
||||
local function register_keybinding_events()
|
||||
vim.keymap.set({ "n", "v" }, Config.mappings.diff.ours, function()
|
||||
if vim.api.nvim_get_current_buf() ~= bufnr then return end
|
||||
local diff_block, idx = get_current_diff_block()
|
||||
if not diff_block then return end
|
||||
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id)
|
||||
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id)
|
||||
vim.api.nvim_buf_set_lines(
|
||||
bufnr,
|
||||
diff_block.new_start_line - 1,
|
||||
diff_block.new_end_line,
|
||||
false,
|
||||
diff_block.old_lines
|
||||
)
|
||||
diff_block.incoming_extmark_id = nil
|
||||
diff_block.delete_extmark_id = nil
|
||||
remove_diff_block(idx, false)
|
||||
local next_diff_block = get_next_diff_block()
|
||||
if next_diff_block then
|
||||
local winnr = Utils.get_winid(bufnr)
|
||||
vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
end
|
||||
end)
|
||||
|
||||
vim.keymap.set({ "n", "v" }, Config.mappings.diff.theirs, function()
|
||||
if vim.api.nvim_get_current_buf() ~= bufnr then return end
|
||||
local diff_block, idx = get_current_diff_block()
|
||||
if not diff_block then return end
|
||||
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.incoming_extmark_id)
|
||||
pcall(vim.api.nvim_buf_del_extmark, bufnr, NAMESPACE, diff_block.delete_extmark_id)
|
||||
diff_block.incoming_extmark_id = nil
|
||||
diff_block.delete_extmark_id = nil
|
||||
remove_diff_block(idx, true)
|
||||
local next_diff_block = get_next_diff_block()
|
||||
if next_diff_block then
|
||||
local winnr = Utils.get_winid(bufnr)
|
||||
vim.api.nvim_win_set_cursor(winnr, { next_diff_block.new_start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
end
|
||||
end)
|
||||
|
||||
vim.keymap.set({ "n", "v" }, Config.mappings.diff.next, function()
|
||||
if vim.api.nvim_get_current_buf() ~= bufnr then return end
|
||||
local diff_block = get_next_diff_block()
|
||||
if not diff_block then return end
|
||||
local winnr = Utils.get_winid(bufnr)
|
||||
vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
end)
|
||||
|
||||
vim.keymap.set({ "n", "v" }, Config.mappings.diff.prev, function()
|
||||
if vim.api.nvim_get_current_buf() ~= bufnr then return end
|
||||
local diff_block = get_prev_diff_block()
|
||||
if not diff_block then return end
|
||||
local winnr = Utils.get_winid(bufnr)
|
||||
vim.api.nvim_win_set_cursor(winnr, { diff_block.new_start_line, 0 })
|
||||
vim.api.nvim_win_call(winnr, function() vim.cmd("normal! zz") end)
|
||||
end)
|
||||
end
|
||||
|
||||
local function unregister_keybinding_events()
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.ours)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.theirs)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.next)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "n", Config.mappings.diff.prev)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.ours)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.theirs)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.next)
|
||||
pcall(vim.api.nvim_buf_del_keymap, bufnr, "v", Config.mappings.diff.prev)
|
||||
end
|
||||
|
||||
local augroup = vim.api.nvim_create_augroup("avante_replace_in_file", { clear = true })
|
||||
local function clear()
|
||||
if bufnr and not vim.api.nvim_buf_is_valid(bufnr) then return end
|
||||
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
|
||||
vim.api.nvim_buf_clear_namespace(bufnr, KEYBINDING_NAMESPACE, 0, -1)
|
||||
unregister_keybinding_events()
|
||||
pcall(vim.api.nvim_del_augroup_by_id, augroup)
|
||||
end
|
||||
|
||||
local function insert_diff_blocks_new_lines()
|
||||
local base_line_ = 0
|
||||
for _, diff_block in ipairs(diff_blocks) do
|
||||
local start_line = diff_block.start_line + base_line_
|
||||
local end_line = diff_block.end_line + base_line_
|
||||
base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines
|
||||
vim.api.nvim_buf_set_lines(bufnr, start_line - 1, end_line, false, diff_block.new_lines)
|
||||
end
|
||||
end
|
||||
|
||||
local function highlight_diff_blocks()
|
||||
vim.api.nvim_buf_clear_namespace(bufnr, NAMESPACE, 0, -1)
|
||||
local base_line_ = 0
|
||||
local max_col = vim.o.columns
|
||||
for _, diff_block in ipairs(diff_blocks) do
|
||||
local start_line = diff_block.start_line + base_line_
|
||||
base_line_ = base_line_ + #diff_block.new_lines - #diff_block.old_lines
|
||||
local deleted_virt_lines = vim
|
||||
.iter(diff_block.old_lines)
|
||||
:map(function(line)
|
||||
--- append spaces to the end of the line
|
||||
local line_ = line .. string.rep(" ", max_col - #line)
|
||||
return { { line_, Highlights.TO_BE_DELETED_WITHOUT_STRIKETHROUGH } }
|
||||
end)
|
||||
:totable()
|
||||
local extmark_line = math.max(0, start_line - 2)
|
||||
local delete_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, extmark_line, 0, {
|
||||
virt_lines = deleted_virt_lines,
|
||||
hl_eol = true,
|
||||
hl_mode = "combine",
|
||||
})
|
||||
local end_row = start_line + #diff_block.new_lines - 1
|
||||
local incoming_extmark_id = vim.api.nvim_buf_set_extmark(bufnr, NAMESPACE, start_line - 1, 0, {
|
||||
hl_group = Highlights.INCOMING,
|
||||
hl_eol = true,
|
||||
hl_mode = "combine",
|
||||
end_row = end_row,
|
||||
})
|
||||
diff_block.delete_extmark_id = delete_extmark_id
|
||||
diff_block.incoming_extmark_id = incoming_extmark_id
|
||||
end
|
||||
end
|
||||
|
||||
insert_diff_blocks_new_lines()
|
||||
highlight_diff_blocks()
|
||||
register_cursor_move_events()
|
||||
register_keybinding_events()
|
||||
|
||||
Helpers.confirm("Are you sure you want to apply this modification?", function(ok, reason)
|
||||
clear()
|
||||
if not ok then
|
||||
vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, original_lines)
|
||||
on_complete(false, "User declined, reason: " .. (reason or "unknown"))
|
||||
return
|
||||
end
|
||||
vim.api.nvim_buf_call(bufnr, function() vim.cmd("noautocmd write") end)
|
||||
if session_ctx then Helpers.mark_as_not_viewed(opts.path, session_ctx) end
|
||||
on_complete(true, nil)
|
||||
end, { focus = not Config.behaviour.auto_focus_on_diff_view }, session_ctx)
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -13,7 +13,7 @@ M.name = "str_replace"
|
||||
M.description =
|
||||
"The str_replace tool allows you to replace a specific string in a file with a new string. This is used for making precise edits."
|
||||
|
||||
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end
|
||||
function M.enabled() return false end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
@@ -65,8 +65,8 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not file then return false, "file not found: " .. abs_path end
|
||||
if opts.old_str == nil then return false, "old_str not provided" end
|
||||
if opts.new_str == nil then return false, "new_str not provided" end
|
||||
Utils.debug("old_str", opts.old_str)
|
||||
Utils.debug("new_str", opts.new_str)
|
||||
-- Utils.debug("old_str", opts.old_str)
|
||||
-- Utils.debug("new_str", opts.new_str)
|
||||
local bufnr, err = Helpers.get_bufnr(abs_path)
|
||||
if err then return false, err end
|
||||
local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)
|
||||
@@ -77,7 +77,7 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
for i = 1, #lines - #old_lines + 1 do
|
||||
local match = true
|
||||
for j = 1, #old_lines do
|
||||
if lines[i + j - 1] ~= old_lines[j] then
|
||||
if Utils.remove_indentation(lines[i + j - 1]) ~= Utils.remove_indentation(old_lines[j]) then
|
||||
match = false
|
||||
break
|
||||
end
|
||||
@@ -89,11 +89,20 @@ function M.func(opts, on_log, on_complete, session_ctx)
|
||||
end
|
||||
end
|
||||
if start_line == nil or end_line == nil then
|
||||
on_complete(false, "Failed to find the old string: " .. opts.old_str)
|
||||
on_complete(false, "Failed to find the old string:\n" .. opts.old_str)
|
||||
return
|
||||
end
|
||||
local old_str = opts.old_str
|
||||
local new_str = opts.new_str
|
||||
local original_indentation = Utils.get_indentation(lines[start_line])
|
||||
if original_indentation ~= Utils.get_indentation(old_lines[1]) then
|
||||
old_lines = vim.tbl_map(function(line) return original_indentation .. line end, old_lines)
|
||||
new_lines = vim.tbl_map(function(line) return original_indentation .. line end, new_lines)
|
||||
old_str = table.concat(old_lines, "\n")
|
||||
new_str = table.concat(new_lines, "\n")
|
||||
end
|
||||
---@diagnostic disable-next-line: assign-type-mismatch, missing-fields
|
||||
local patch = vim.diff(opts.old_str, opts.new_str, { ---@type integer[][]
|
||||
local patch = vim.diff(old_str, new_str, { ---@type integer[][]
|
||||
algorithm = "histogram",
|
||||
result_type = "indices",
|
||||
ctxlen = vim.o.scrolloff,
|
||||
|
||||
@@ -10,7 +10,7 @@ M.name = "undo_edit"
|
||||
|
||||
M.description = "The undo_edit tool allows you to revert the last edit made to a file."
|
||||
|
||||
function M.enabled() return require("avante.config").behaviour.enable_claude_text_editor_tool_mode end
|
||||
function M.enabled() return require("avante.config").mode == "agentic" end
|
||||
|
||||
---@type AvanteLLMToolParam
|
||||
M.param = {
|
||||
|
||||
@@ -76,17 +76,6 @@ M.returns = {
|
||||
function M.func(opts, on_log, on_complete, session_ctx)
|
||||
if not on_complete then return false, "on_complete not provided" end
|
||||
if on_log then on_log("path: " .. opts.path) end
|
||||
if Helpers.already_in_context(opts.path) then
|
||||
on_complete(nil, "Ooooops! This file is already in the context! Why you are trying to read it again?")
|
||||
return
|
||||
end
|
||||
if session_ctx then
|
||||
if Helpers.already_viewed(opts.path, session_ctx) then
|
||||
on_complete(nil, "Ooooops! You have already viewed this file! Why you are trying to read it again?")
|
||||
return
|
||||
end
|
||||
Helpers.mark_as_viewed(opts.path, session_ctx)
|
||||
end
|
||||
local abs_path = Helpers.get_abs_path(opts.path)
|
||||
if not Helpers.has_permission_to_access(abs_path) then return false, "No permission to access path: " .. abs_path end
|
||||
if not Path:new(abs_path):exists() then return false, "Path not found: " .. abs_path end
|
||||
|
||||
@@ -16,9 +16,9 @@ local function generate_project_dirname_in_storage(bufnr)
|
||||
buf = bufnr,
|
||||
})
|
||||
-- Replace path separators with double underscores
|
||||
local path_with_separators = fn.substitute(project_root, "/", "__", "g")
|
||||
local path_with_separators = string.gsub(project_root, "/", "__")
|
||||
-- Replace other non-alphanumeric characters with single underscores
|
||||
local dirname = fn.substitute(path_with_separators, "[^A-Za-z0-9._]", "_", "g")
|
||||
local dirname = string.gsub(path_with_separators, "[^A-Za-z0-9._]", "_")
|
||||
return tostring(Path:new("projects"):joinpath(dirname))
|
||||
end
|
||||
|
||||
@@ -34,6 +34,7 @@ function History.get_history_dir(bufnr)
|
||||
return history_dir
|
||||
end
|
||||
|
||||
---@return avante.ChatHistory[]
|
||||
function History.list(bufnr)
|
||||
local history_dir = History.get_history_dir(bufnr)
|
||||
local files = vim.fn.glob(tostring(history_dir:joinpath("*.json")), true, true)
|
||||
@@ -53,8 +54,10 @@ function History.list(bufnr)
|
||||
table.sort(res, function(a, b)
|
||||
if a.filename == latest_filename then return true end
|
||||
if b.filename == latest_filename then return false end
|
||||
local timestamp_a = #a.entries > 0 and a.entries[#a.entries].timestamp or a.timestamp
|
||||
local timestamp_b = #b.entries > 0 and b.entries[#b.entries].timestamp or b.timestamp
|
||||
local a_messages = Utils.get_history_messages(a)
|
||||
local b_messages = Utils.get_history_messages(b)
|
||||
local timestamp_a = #a_messages > 0 and a_messages[#a_messages].timestamp or a.timestamp
|
||||
local timestamp_b = #b_messages > 0 and b_messages[#b_messages].timestamp or b.timestamp
|
||||
return timestamp_a > timestamp_b
|
||||
end)
|
||||
return res
|
||||
@@ -117,8 +120,8 @@ function History.new(bufnr)
|
||||
---@type avante.ChatHistory
|
||||
local history = {
|
||||
title = "untitled",
|
||||
timestamp = tostring(os.date("%Y-%m-%d %H:%M:%S")),
|
||||
entries = {},
|
||||
timestamp = Utils.get_timestamp(),
|
||||
messages = {},
|
||||
filename = filepath_to_filename(filepath),
|
||||
}
|
||||
return history
|
||||
@@ -169,11 +172,10 @@ function Prompt.get_builtin_prompts_filepath(mode) return string.format("%s.avan
|
||||
local _templates_lib = nil
|
||||
|
||||
Prompt.custom_modes = {
|
||||
planning = true,
|
||||
agentic = true,
|
||||
legacy = true,
|
||||
editing = true,
|
||||
suggesting = true,
|
||||
["cursor-planning"] = true,
|
||||
["cursor-applying"] = true,
|
||||
}
|
||||
|
||||
Prompt.custom_prompts_contents = {}
|
||||
|
||||
@@ -58,6 +58,7 @@ function M:parse_stream_data(ctx, data, opts)
|
||||
end
|
||||
|
||||
function M:parse_response_without_stream(data, event_state, opts)
|
||||
if opts.on_chunk == nil then return end
|
||||
local bedrock_match = data:gmatch("exception(%b{})")
|
||||
opts.on_chunk("\n**Exception caught**\n\n")
|
||||
for bedrock_data_match in bedrock_match do
|
||||
|
||||
@@ -2,7 +2,7 @@ local Utils = require("avante.utils")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local P = require("avante.providers")
|
||||
local Config = require("avante.config")
|
||||
local StreamingJsonParser = require("avante.utils.streaming_json_parser")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -139,63 +139,6 @@ function M:parse_messages(opts)
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
if tool_history.tool_use then
|
||||
local msg = {
|
||||
role = "assistant",
|
||||
content = {},
|
||||
}
|
||||
if tool_history.tool_use.thinking_blocks then
|
||||
for _, thinking_block in ipairs(tool_history.tool_use.thinking_blocks) do
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "thinking",
|
||||
thinking = thinking_block.thinking,
|
||||
signature = thinking_block.signature,
|
||||
}
|
||||
end
|
||||
end
|
||||
if tool_history.tool_use.redacted_thinking_blocks then
|
||||
for _, redacted_thinking_block in ipairs(tool_history.tool_use.redacted_thinking_blocks) do
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "redacted_thinking",
|
||||
data = redacted_thinking_block.data,
|
||||
}
|
||||
end
|
||||
end
|
||||
if tool_history.tool_use.response_contents then
|
||||
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "text",
|
||||
text = response_content,
|
||||
}
|
||||
end
|
||||
end
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "tool_use",
|
||||
id = tool_history.tool_use.id,
|
||||
name = tool_history.tool_use.name,
|
||||
input = vim.json.decode(tool_history.tool_use.input_json),
|
||||
}
|
||||
messages[#messages + 1] = msg
|
||||
end
|
||||
|
||||
if tool_history.tool_result then
|
||||
messages[#messages + 1] = {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.content,
|
||||
is_error = tool_history.tool_result.is_error,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
@@ -226,14 +169,51 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
local content_block = jsn.content_block
|
||||
content_block.stoppped = false
|
||||
ctx.content_blocks[jsn.index + 1] = content_block
|
||||
if content_block.type == "thinking" then opts.on_chunk("<think>\n") end
|
||||
if content_block.type == "tool_use" and opts.on_partial_tool_use then
|
||||
opts.on_partial_tool_use({
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
partial_json = {},
|
||||
if content_block.type == "text" then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generating",
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
if content_block.type == "thinking" then
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
if opts.on_messages_add then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = content_block.thinking,
|
||||
signature = content_block.signature,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
end
|
||||
end
|
||||
if content_block.type == "tool_use" and opts.on_messages_add then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
input = {},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
end
|
||||
elseif event_state == "content_block_delta" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
@@ -242,23 +222,35 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if jsn.delta.type == "input_json_delta" then
|
||||
if not content_block.input_json then content_block.input_json = "" end
|
||||
content_block.input_json = content_block.input_json .. jsn.delta.partial_json
|
||||
if opts.on_partial_tool_use then
|
||||
local streaming_json_parser = StreamingJsonParser:new()
|
||||
local partial_json = streaming_json_parser:parse(content_block.input_json)
|
||||
opts.on_partial_tool_use({
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
partial_json = partial_json or {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
return
|
||||
elseif jsn.delta.type == "thinking_delta" then
|
||||
content_block.thinking = content_block.thinking .. jsn.delta.thinking
|
||||
opts.on_chunk(jsn.delta.thinking)
|
||||
if opts.on_chunk then opts.on_chunk(jsn.delta.thinking) end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = content_block.thinking,
|
||||
signature = content_block.signature,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "text_delta" then
|
||||
content_block.text = content_block.text .. jsn.delta.text
|
||||
opts.on_chunk(jsn.delta.text)
|
||||
if opts.on_chunk then opts.on_chunk(jsn.delta.text) end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "signature_delta" then
|
||||
if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end
|
||||
ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature
|
||||
@@ -268,12 +260,56 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if not ok then return end
|
||||
local content_block = ctx.content_blocks[jsn.index + 1]
|
||||
content_block.stoppped = true
|
||||
if content_block.type == "text" then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
if content_block.type == "tool_use" then
|
||||
local complete_json = vim.json.decode(content_block.input_json)
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
input = complete_json or {},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
if content_block.type == "thinking" then
|
||||
if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then
|
||||
opts.on_chunk("\n</think>\n\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n\n")
|
||||
if opts.on_chunk then
|
||||
if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then
|
||||
opts.on_chunk("\n</think>\n\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n\n")
|
||||
end
|
||||
end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = content_block.thinking,
|
||||
signature = content_block.signature,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
elseif event_state == "message_delta" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
@@ -281,49 +317,20 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if jsn.delta.stop_reason == "end_turn" then
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
elseif jsn.delta.stop_reason == "tool_use" then
|
||||
---@type AvanteLLMToolUse[]
|
||||
local tool_use_list = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(function(content_block) return content_block.stoppped and content_block.type == "tool_use" end)
|
||||
:map(function(content_block)
|
||||
local response_contents = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end)
|
||||
:map(function(content_block_) return content_block_.text end)
|
||||
:totable()
|
||||
local thinking_blocks = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end)
|
||||
:map(function(content_block_)
|
||||
---@type AvanteLLMThinkingBlock
|
||||
return { thinking = content_block_.thinking, signature = content_block_.signature }
|
||||
end)
|
||||
:totable()
|
||||
local redacted_thinking_blocks = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(
|
||||
function(content_block_) return content_block_.stoppped and content_block_.type == "redacted_thinking" end
|
||||
)
|
||||
:map(function(content_block_)
|
||||
---@type AvanteLLMRedactedThinkingBlock
|
||||
return { data = content_block_.data }
|
||||
end)
|
||||
:totable()
|
||||
---@type AvanteLLMToolUse
|
||||
return {
|
||||
name = content_block.name,
|
||||
local tool_use_list = {}
|
||||
for _, content_block in ipairs(ctx.content_blocks) do
|
||||
if content_block.type == "tool_use" then
|
||||
table.insert(tool_use_list, {
|
||||
id = content_block.id,
|
||||
name = content_block.name,
|
||||
input_json = content_block.input_json,
|
||||
response_contents = response_contents,
|
||||
thinking_blocks = thinking_blocks,
|
||||
redacted_thinking_blocks = redacted_thinking_blocks,
|
||||
}
|
||||
end)
|
||||
:totable()
|
||||
})
|
||||
end
|
||||
end
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
-- tool_use_list = tool_use_list,
|
||||
usage = jsn.usage,
|
||||
tool_use_list = tool_use_list,
|
||||
})
|
||||
end
|
||||
return
|
||||
@@ -351,7 +358,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
local tools = {}
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
if Config.behaviour.enable_claude_text_editor_tool_mode then
|
||||
if Config.mode == "agentic" then
|
||||
if tool.name == "create_file" then goto continue end
|
||||
if tool.name == "view" then goto continue end
|
||||
if tool.name == "str_replace" then goto continue end
|
||||
@@ -364,7 +371,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
end
|
||||
end
|
||||
|
||||
if prompt_opts.tools and #prompt_opts.tools > 0 and Config.behaviour.enable_claude_text_editor_tool_mode then
|
||||
if prompt_opts.tools and #prompt_opts.tools > 0 and Config.mode == "agentic" then
|
||||
if provider_conf.model:match("claude%-3%-7%-sonnet") then
|
||||
table.insert(tools, {
|
||||
type = "text_editor_20250124",
|
||||
|
||||
@@ -211,11 +211,7 @@ M.role_map = {
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
M.parse_messages = OpenAI.parse_messages
|
||||
|
||||
M.parse_response = OpenAI.parse_response
|
||||
|
||||
M.is_reasoning_model = OpenAI.is_reasoning_model
|
||||
setmetatable(M, { __index = OpenAI })
|
||||
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
-- refresh token synchronously, only if it has expired
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local OpenAI = require("avante.providers").openai
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -10,14 +11,32 @@ M.role_map = {
|
||||
user = "user",
|
||||
assistant = "model",
|
||||
}
|
||||
-- M.tokenizer_id = "google/gemma-2b"
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
---@param tool AvanteLLMTool
|
||||
function M:transform_to_function_declaration(tool)
|
||||
local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
|
||||
local parameters = nil
|
||||
if not vim.tbl_isempty(input_schema_properties) then
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = input_schema_properties,
|
||||
required = required,
|
||||
}
|
||||
end
|
||||
return {
|
||||
name = tool.name,
|
||||
description = tool.get_description and tool.get_description() or tool.description,
|
||||
parameters = parameters,
|
||||
}
|
||||
end
|
||||
|
||||
function M:parse_messages(opts)
|
||||
local contents = {}
|
||||
local prev_role = nil
|
||||
|
||||
local tool_id_to_name = {}
|
||||
vim.iter(opts.messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role then
|
||||
@@ -54,9 +73,27 @@ function M:parse_messages(opts)
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "tool_use" then
|
||||
table.insert(parts, { text = item.name })
|
||||
tool_id_to_name[item.id] = item.name
|
||||
role = "model"
|
||||
table.insert(parts, {
|
||||
functionCall = {
|
||||
name = item.name,
|
||||
args = item.input,
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "tool_result" then
|
||||
table.insert(parts, { text = item.content })
|
||||
role = "function"
|
||||
local ok, content = pcall(vim.json.decode, item.content)
|
||||
if not ok then content = item.content end
|
||||
table.insert(parts, {
|
||||
functionResponse = {
|
||||
name = tool_id_to_name[item.tool_use_id],
|
||||
response = {
|
||||
name = tool_id_to_name[item.tool_use_id],
|
||||
content = content,
|
||||
},
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "thinking" then
|
||||
table.insert(parts, { text = item.thinking })
|
||||
elseif type(item) == "table" and item.type == "redacted_thinking" then
|
||||
@@ -96,22 +133,43 @@ end
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
local ok, json = pcall(vim.json.decode, data_stream)
|
||||
if not ok then opts.on_stop({ reason = "error", error = json }) end
|
||||
if json.candidates then
|
||||
if #json.candidates > 0 then
|
||||
if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then
|
||||
opts.on_chunk(json.candidates[1].content.parts[1].text)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
else
|
||||
opts.on_chunk(json.candidates[1].content.parts[1].text)
|
||||
if json.candidates and #json.candidates > 0 then
|
||||
local candidate = json.candidates[1]
|
||||
---@type AvanteLLMToolUse[]
|
||||
local tool_use_list = {}
|
||||
for _, part in ipairs(candidate.content.parts) do
|
||||
if part.text then
|
||||
if opts.on_chunk then opts.on_chunk(part.text) end
|
||||
OpenAI:add_text_message(ctx, part.text, "generating", opts)
|
||||
elseif part.functionCall then
|
||||
if not ctx.function_call_id then ctx.function_call_id = 0 end
|
||||
ctx.function_call_id = ctx.function_call_id + 1
|
||||
local tool_use = {
|
||||
id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
|
||||
name = part.functionCall.name,
|
||||
input_json = vim.json.encode(part.functionCall.args),
|
||||
}
|
||||
table.insert(tool_use_list, tool_use)
|
||||
OpenAI:add_tool_use_message(tool_use, "generated", opts)
|
||||
end
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
if candidate.finishReason and candidate.finishReason == "STOP" then
|
||||
OpenAI:finish_pending_messages(ctx, opts)
|
||||
if #tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
end
|
||||
else
|
||||
OpenAI:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
generationConfig = {
|
||||
@@ -125,6 +183,21 @@ function M:parse_curl_args(prompt_opts)
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||
|
||||
local function_declarations = {}
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(function_declarations, self:transform_to_function_declaration(tool))
|
||||
end
|
||||
end
|
||||
|
||||
if #function_declarations > 0 then
|
||||
request_body.tools = {
|
||||
{
|
||||
functionDeclarations = function_declarations,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(
|
||||
provider_conf.endpoint,
|
||||
|
||||
@@ -215,14 +215,6 @@ function M.setup()
|
||||
E.setup({ provider = auto_suggestions_provider })
|
||||
end
|
||||
|
||||
if Config.behaviour.enable_cursor_planning_mode then
|
||||
local cursor_applying_provider_name = Config.cursor_applying_provider or Config.provider
|
||||
local cursor_applying_provider = M[cursor_applying_provider_name]
|
||||
if cursor_applying_provider and cursor_applying_provider ~= provider then
|
||||
E.setup({ provider = cursor_applying_provider })
|
||||
end
|
||||
end
|
||||
|
||||
if Config.memory_summary_provider then
|
||||
local memory_summary_provider = M[Config.memory_summary_provider]
|
||||
if memory_summary_provider and memory_summary_provider ~= provider then
|
||||
@@ -277,4 +269,13 @@ function M.get_config(provider_name)
|
||||
return type(cur) == "function" and cur() or cur
|
||||
end
|
||||
|
||||
function M.get_memory_summary_provider()
|
||||
local provider_name = Config.memory_summary_provider
|
||||
if provider_name == nil then
|
||||
if M.openai.is_env_set() then provider_name = "openai-gpt-4o-mini" end
|
||||
end
|
||||
if provider_name == nil then provider_name = Config.provider end
|
||||
return M[provider_name]
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -16,7 +16,7 @@ M.is_reasoning_model = P.openai.is_reasoning_model
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
function M:parse_stream_data(ctx, data, handler_opts)
|
||||
function M:parse_stream_data(ctx, data, opts)
|
||||
local ok, json_data = pcall(vim.json.decode, data)
|
||||
if not ok or not json_data then
|
||||
-- Add debug logging
|
||||
@@ -26,11 +26,13 @@ function M:parse_stream_data(ctx, data, handler_opts)
|
||||
|
||||
if json_data.message and json_data.message.content then
|
||||
local content = json_data.message.content
|
||||
if content and content ~= "" then handler_opts.on_chunk(content) end
|
||||
P.openai:add_text_message(ctx, content, "generating", opts)
|
||||
if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end
|
||||
end
|
||||
|
||||
if json_data.done then
|
||||
handler_opts.on_stop({ reason = "complete" })
|
||||
P.openai:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
@@ -2,7 +2,7 @@ local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local Providers = require("avante.providers")
|
||||
local StreamingJsonParser = require("avante.utils.streaming_json_parser")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -164,117 +164,154 @@ function M:parse_messages(opts)
|
||||
table.insert(final_messages, message)
|
||||
end)
|
||||
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
table.insert(final_messages, {
|
||||
role = self.role_map["assistant"],
|
||||
tool_calls = {
|
||||
{
|
||||
id = tool_history.tool_use.id,
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool_history.tool_use.name,
|
||||
arguments = tool_history.tool_use.input_json,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
local result_content = tool_history.tool_result.content or ""
|
||||
table.insert(final_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content,
|
||||
})
|
||||
return final_messages
|
||||
end
|
||||
|
||||
function M:finish_pending_messages(ctx, opts)
|
||||
if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end
|
||||
if ctx.tool_use_list then
|
||||
for _, tool_use in ipairs(ctx.tool_use_list) do
|
||||
if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return final_messages
|
||||
function M:add_text_message(ctx, text, state, opts)
|
||||
if ctx.content == nil then ctx.content = "" end
|
||||
ctx.content = ctx.content .. text
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = ctx.content,
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.content_uuid,
|
||||
})
|
||||
ctx.content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:add_thinking_message(ctx, text, state, opts)
|
||||
if ctx.reasonging_content == nil then ctx.reasonging_content = "" end
|
||||
ctx.reasonging_content = ctx.reasonging_content .. text
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = ctx.reasonging_content,
|
||||
signature = "",
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.reasonging_content_uuid,
|
||||
})
|
||||
ctx.reasonging_content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:add_tool_use_message(tool_use, state, opts)
|
||||
local jsn = nil
|
||||
if state == "generated" then jsn = vim.json.decode(tool_use.input_json) end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = tool_use.name,
|
||||
id = tool_use.id,
|
||||
input = jsn or {},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = tool_use.uuid,
|
||||
})
|
||||
tool_use.uuid = msg.uuid
|
||||
tool_use.state = state
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
end
|
||||
if data_stream:match('"delta":') then
|
||||
---@type AvanteOpenAIChatResponse
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
if jsn.choices and jsn.choices[1] then
|
||||
local choice = jsn.choices[1]
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
if choice.delta.content and choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end
|
||||
opts.on_stop({ reason = "complete" })
|
||||
elseif choice.finish_reason == "tool_calls" then
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
usage = jsn.usage,
|
||||
tool_use_list = ctx.tool_use_list,
|
||||
})
|
||||
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
opts.on_chunk("<think>\n")
|
||||
if not data_stream:match('"delta":') then return end
|
||||
---@type AvanteOpenAIChatResponse
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
if not jsn.choices or not jsn.choices[1] then return end
|
||||
local choice = jsn.choices[1]
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
if choice.delta.content and choice.delta.content ~= vim.NIL then
|
||||
self:add_text_message(ctx, choice.delta.content, "generated", opts)
|
||||
opts.on_chunk(choice.delta.content)
|
||||
end
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
elseif choice.finish_reason == "tool_calls" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
-- tool_use_list = ctx.tool_use_list,
|
||||
usage = jsn.usage,
|
||||
})
|
||||
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning_content
|
||||
self:add_thinking_message(ctx, choice.delta.reasoning_content, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning_content) end
|
||||
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning
|
||||
self:add_thinking_message(ctx, choice.delta.reasoning, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning) end
|
||||
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
|
||||
for _, tool_call in ipairs(choice.delta.tool_calls) do
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then
|
||||
local prev_tool_use = ctx.tool_use_list[tool_call.index]
|
||||
self:add_tool_use_message(prev_tool_use, "generated", opts)
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning_content
|
||||
opts.on_chunk(choice.delta.reasoning_content)
|
||||
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
opts.on_chunk("<think>\n")
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning
|
||||
opts.on_chunk(choice.delta.reasoning)
|
||||
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
|
||||
for _, tool_call in ipairs(choice.delta.tool_calls) do
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
input_json = "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
if opts.on_partial_tool_use then
|
||||
opts.on_partial_tool_use({
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
partial_json = {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
if opts.on_partial_tool_use then
|
||||
local parser = StreamingJsonParser:new()
|
||||
local partial_json = parser:parse(tool_use.input_json)
|
||||
opts.on_partial_tool_use({
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
partial_json = partial_json or {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif choice.delta.content then
|
||||
if
|
||||
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
|
||||
then
|
||||
ctx.returned_think_end_tag = true
|
||||
if
|
||||
ctx.last_think_content
|
||||
and ctx.last_think_content ~= vim.NIL
|
||||
and ctx.last_think_content:sub(-1) ~= "\n"
|
||||
then
|
||||
opts.on_chunk("\n</think>\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n")
|
||||
end
|
||||
end
|
||||
if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
input_json = "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
end
|
||||
end
|
||||
elseif choice.delta.content then
|
||||
if
|
||||
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
|
||||
then
|
||||
ctx.returned_think_end_tag = true
|
||||
if opts.on_chunk then
|
||||
if ctx.last_think_content and ctx.last_think_content ~= vim.NIL and ctx.last_think_content:sub(-1) ~= "\n" then
|
||||
opts.on_chunk("\n</think>\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n")
|
||||
end
|
||||
end
|
||||
self:add_thinking_message(ctx, "", "generated", opts)
|
||||
end
|
||||
if choice.delta.content ~= vim.NIL then
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.content) end
|
||||
self:add_text_message(ctx, choice.delta.content, "generating", opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
10
lua/avante/templates/agentic.avanterules
Normal file
10
lua/avante/templates/agentic.avanterules
Normal file
@@ -0,0 +1,10 @@
|
||||
{% extends "base.avanterules" %}
|
||||
{% block extra_prompt %}
|
||||
Always reply to the user in the same language they are using.
|
||||
|
||||
Don't just provide code suggestions, use the `replace_in_file` tool to help users fulfill their needs.
|
||||
|
||||
After the tool call is complete, please do not output the entire file content.
|
||||
|
||||
Before calling the tool, be sure to explain the reason for calling the tool.
|
||||
{% endblock %}
|
||||
@@ -1,6 +0,0 @@
|
||||
{% extends "base.avanterules" %}
|
||||
{% block extra_prompt %}
|
||||
Always reply to the user in the same language they are using.
|
||||
|
||||
Don't just provide code suggestions, use the `str_replace` tool to help users fulfill their needs.
|
||||
{% endblock %}
|
||||
@@ -1,4 +0,0 @@
|
||||
You are a coding assistant that helps merge code updates, ensuring every modification is fully integrated.
|
||||
|
||||
{% block custom_prompt %}
|
||||
{% endblock %}
|
||||
@@ -1,46 +0,0 @@
|
||||
{% extends "base.avanterules" %}
|
||||
{%- if ask %}
|
||||
{% block extra_prompt %}
|
||||
You are an intelligent programmer, powered by {{ model_name }}. You are happy to help answer any questions that the user has (usually they will be about coding).
|
||||
|
||||
1. When the user is asking for edits to their code, please output a simplified version of the code block that highlights the changes necessary and adds comments to indicate where unchanged code has been skipped. For example:
|
||||
```language:path/to/file
|
||||
// ... existing code ...
|
||||
{% raw -%}
|
||||
{{ edit_1 }}
|
||||
{%- endraw %}
|
||||
// ... existing code ...
|
||||
{% raw -%}
|
||||
{{ edit_2 }}
|
||||
{%- endraw %}
|
||||
// ... existing code ...
|
||||
```
|
||||
The user can see the entire file, so they prefer to only read the updates to the code. Often this will mean that the start/end of the file will be skipped, but that's okay! Rewrite the entire file only if specifically requested. Always provide a brief explanation of the updates, unless the user specifically requests only the code.
|
||||
|
||||
These edit codeblocks are also read by a less intelligent language model, colloquially called the apply model, to update the file. To help specify the edit to the apply model, you will be very careful when generating the codeblock to not introduce ambiguity. You will specify all unchanged regions (code and comments) of the file with "// … existing code …" comment markers. This will ensure the apply model will not delete existing unchanged code or comments when editing the file. You will not mention the apply model.
|
||||
|
||||
2. Do not lie or make up facts.
|
||||
|
||||
3. If a user messages you in a foreign language, please respond in that language.
|
||||
|
||||
4. Format your response in markdown.
|
||||
|
||||
5. When writing out new code blocks, please specify the language ID after the initial backticks, like so:
|
||||
```python
|
||||
{% raw -%}
|
||||
{{ code }}
|
||||
{%- endraw %}
|
||||
```
|
||||
|
||||
6. When writing out code blocks for an existing file, please also specify the file path after the initial backticks and restate the method / class your codeblock belongs to, like so:
|
||||
```language:some/other/file
|
||||
function AIChatHistory() {
|
||||
...
|
||||
{% raw -%}
|
||||
{{ code }}
|
||||
{%- endraw %}
|
||||
...
|
||||
}
|
||||
```
|
||||
{% endblock %}
|
||||
{%- endif %}
|
||||
@@ -75,15 +75,32 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field on_start AvanteLLMStartCallback
|
||||
---@field on_chunk AvanteLLMChunkCallback
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
---@field on_partial_tool_use? fun(tool_use: AvantePartialLLMToolUse): nil
|
||||
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
|
||||
---@field on_state_change? fun(state: avante.GenerateState): nil
|
||||
---
|
||||
---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string }
|
||||
---
|
||||
|
||||
---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string
|
||||
---
|
||||
|
||||
---@class AvanteLLMMessage
|
||||
---@field role "user" | "assistant"
|
||||
---@field content AvanteLLMMessageContent
|
||||
|
||||
---@class avante.HistoryMessage
|
||||
---@field message AvanteLLMMessage
|
||||
---@field timestamp string
|
||||
---@field state avante.HistoryMessageState
|
||||
---@field uuid string | nil
|
||||
---@field displayed_content string | nil
|
||||
---@field visible boolean | nil
|
||||
---@field is_context boolean | nil
|
||||
---@field is_user_submission boolean | nil
|
||||
---@field provider string | nil
|
||||
---@field model string | nil
|
||||
---@field selected_code AvanteSelectedCode | nil
|
||||
---@field selected_filepaths string[] | nil
|
||||
---@field tool_use_logs string[] | nil
|
||||
---@field just_for_display boolean | nil
|
||||
---
|
||||
---@class AvanteLLMToolResult
|
||||
---@field tool_name string
|
||||
@@ -96,8 +113,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field messages AvanteLLMMessage[]
|
||||
---@field image_paths? string[]
|
||||
---@field tools? AvanteLLMTool[]
|
||||
---@field tool_histories? AvanteLLMToolHistory[]
|
||||
---@field dropped_history_messages? AvanteLLMMessage[]
|
||||
---@field dropped_history_messages? avante.HistoryMessage[]
|
||||
---
|
||||
---@class AvanteGeminiMessage
|
||||
---@field role "user"
|
||||
@@ -236,19 +252,18 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@class AvanteLLMRedactedThinkingBlock
|
||||
---@field data string
|
||||
---
|
||||
---@alias avante.HistoryMessageState "generating" | "generated"
|
||||
---
|
||||
---@class AvantePartialLLMToolUse
|
||||
---@field name string
|
||||
---@field id string
|
||||
---@field partial_json table
|
||||
---@field state "generating" | "generated"
|
||||
---@field state avante.HistoryMessageState
|
||||
---
|
||||
---@class AvanteLLMToolUse
|
||||
---@field name string
|
||||
---@field id string
|
||||
---@field input_json string
|
||||
---@field response_contents? string[]
|
||||
---@field thinking_blocks? AvanteLLMThinkingBlock[]
|
||||
---@field redacted_thinking_blocks? AvanteLLMRedactedThinkingBlock[]
|
||||
---@field input any
|
||||
---
|
||||
---@class AvanteLLMStartCallbackOptions
|
||||
---@field usage? AvanteLLMUsage
|
||||
@@ -257,10 +272,8 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field reason "complete" | "tool_use" | "error" | "rate_limit" | "cancelled"
|
||||
---@field error? string | table
|
||||
---@field usage? AvanteLLMUsage
|
||||
---@field tool_use_list? AvanteLLMToolUse[]
|
||||
---@field retry_after? integer
|
||||
---@field headers? table<string, string>
|
||||
---@field tool_histories? AvanteLLMToolHistory[]
|
||||
---
|
||||
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
|
||||
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
|
||||
@@ -303,7 +316,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field parse_response AvanteResponseParser
|
||||
---@field build_bedrock_payload AvanteBedrockPayloadBuilder
|
||||
---
|
||||
---@alias AvanteLlmMode "planning" | "editing" | "suggesting" | "cursor-planning" | "cursor-applying" | "claude-text-editor-tool"
|
||||
---@alias AvanteLlmMode avante.Mode | "editing" | "suggesting"
|
||||
---
|
||||
---@class AvanteSelectedCode
|
||||
---@field path string
|
||||
@@ -324,7 +337,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field selected_files AvanteSelectedFile[] | nil
|
||||
---@field selected_filepaths string[] | nil
|
||||
---@field diagnostics string | nil
|
||||
---@field history_messages AvanteLLMMessage[] | nil
|
||||
---@field history_messages avante.HistoryMessage[] | nil
|
||||
---@field memory string | nil
|
||||
---
|
||||
---@class AvanteGeneratePromptsOptions: AvanteTemplateOptions
|
||||
@@ -332,7 +345,6 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field mode? AvanteLlmMode
|
||||
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
|
||||
---@field tools? AvanteLLMTool[]
|
||||
---@field tool_histories? AvanteLLMToolHistory[]
|
||||
---@field original_code? string
|
||||
---@field update_snippets? string[]
|
||||
---@field prompt_opts? AvantePromptOptions
|
||||
@@ -342,9 +354,10 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field tool_result? AvanteLLMToolResult
|
||||
---@field tool_use? AvanteLLMToolUse
|
||||
---
|
||||
---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: AvanteLLMMessage[]): nil
|
||||
---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: avante.HistoryMessage[]): nil
|
||||
---
|
||||
---@alias AvanteLLMToolUseState "generating" | "generated" | "running" | "succeeded" | "failed"
|
||||
---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking"
|
||||
---
|
||||
---@class AvanteLLMStreamOptions: AvanteGeneratePromptsOptions
|
||||
---@field on_start AvanteLLMStartCallback
|
||||
@@ -352,7 +365,9 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field on_stop AvanteLLMStopCallback
|
||||
---@field on_memory_summarize? AvanteLLMMemorySummarizeCallback
|
||||
---@field on_tool_log? fun(tool_id: string, tool_name: string, log: string, state: AvanteLLMToolUseState): nil
|
||||
---@field on_partial_tool_use? fun(tool_use: AvantePartialLLMToolUse): nil
|
||||
---@field get_history_messages? fun(): avante.HistoryMessage[]
|
||||
---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil
|
||||
---@field on_state_change? fun(state: avante.GenerateState): nil
|
||||
---
|
||||
---@alias AvanteLLMToolFunc<T> fun(
|
||||
--- input: T,
|
||||
@@ -400,15 +415,14 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field original_response string
|
||||
---@field selected_file {filepath: string}?
|
||||
---@field selected_code AvanteSelectedCode | nil
|
||||
---@field reset_memory boolean?
|
||||
---@field selected_filepaths string[] | nil
|
||||
---@field visible boolean?
|
||||
---@field tool_histories? AvanteLLMToolHistory[]
|
||||
---
|
||||
---@class avante.ChatHistory
|
||||
---@field title string
|
||||
---@field timestamp string
|
||||
---@field entries avante.ChatHistoryEntry[]
|
||||
---@field messages avante.HistoryMessage[] | nil
|
||||
---@field entries avante.ChatHistoryEntry[] | nil
|
||||
---@field memory avante.ChatMemory | nil
|
||||
---@field filename string
|
||||
---@field system_prompt string | nil
|
||||
@@ -416,6 +430,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@class avante.ChatMemory
|
||||
---@field content string
|
||||
---@field last_summarized_timestamp string
|
||||
---@field last_message_uuid string | nil
|
||||
---
|
||||
---@class avante.CurlOpts
|
||||
---@field provider AvanteProviderFunctor
|
||||
@@ -427,7 +442,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field content string
|
||||
---@field uri string
|
||||
---
|
||||
---@alias AvanteSlashCommandBuiltInName "clear" | "help" | "lines" | "reset" | "commit" | "new"
|
||||
---@alias AvanteSlashCommandBuiltInName "clear" | "help" | "lines" | "commit" | "new"
|
||||
---@alias AvanteSlashCommandCallback fun(self: avante.Sidebar, args: string, cb?: fun(args: string): nil): nil
|
||||
---@class AvanteSlashCommand
|
||||
---@field name AvanteSlashCommandBuiltInName | string
|
||||
|
||||
@@ -80,7 +80,7 @@ function M.show(selector)
|
||||
|
||||
selector.on_select(selected_item_ids)
|
||||
|
||||
actions.close(prompt_bufnr)
|
||||
pcall(actions.close, prompt_bufnr)
|
||||
end)
|
||||
return true
|
||||
end,
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
|
||||
---@class avante.utils.history
|
||||
local M = {}
|
||||
|
||||
---@param entries avante.ChatHistoryEntry[]
|
||||
---@return avante.ChatHistoryEntry[]
|
||||
function M.filter_active_entries(entries)
|
||||
local entries_ = {}
|
||||
|
||||
for i = #entries, 1, -1 do
|
||||
local entry = entries[i]
|
||||
if entry.reset_memory then break end
|
||||
table.insert(entries_, 1, entry)
|
||||
end
|
||||
|
||||
return entries_
|
||||
end
|
||||
|
||||
---@param entries avante.ChatHistoryEntry[]
|
||||
---@return AvanteLLMMessage[]
|
||||
function M.entries_to_llm_messages(entries)
|
||||
local current_provider_name = Config.provider
|
||||
local messages = {}
|
||||
for _, entry in ipairs(entries) do
|
||||
if entry.selected_filepaths ~= nil and #entry.selected_filepaths > 0 then
|
||||
local user_content = "SELECTED FILES:\n\n"
|
||||
for _, filepath in ipairs(entry.selected_filepaths) do
|
||||
user_content = user_content .. filepath .. "\n"
|
||||
end
|
||||
table.insert(messages, { role = "user", content = user_content })
|
||||
end
|
||||
if entry.selected_code ~= nil then
|
||||
local user_content_ = "SELECTED CODE:\n\n```"
|
||||
.. (entry.selected_code.file_type or "")
|
||||
.. (entry.selected_code.path and ":" .. entry.selected_code.path or "")
|
||||
.. "\n"
|
||||
.. entry.selected_code.content
|
||||
.. "\n```\n\n"
|
||||
table.insert(messages, { role = "user", content = user_content_ })
|
||||
end
|
||||
if entry.request ~= nil and entry.request ~= "" then
|
||||
table.insert(messages, { role = "user", content = entry.request })
|
||||
end
|
||||
if entry.tool_histories ~= nil and #entry.tool_histories > 0 and entry.provider == current_provider_name then
|
||||
for _, tool_history in ipairs(entry.tool_histories) do
|
||||
local assistant_content = {}
|
||||
if tool_history.tool_use ~= nil then
|
||||
if tool_history.tool_use.response_contents ~= nil then
|
||||
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
|
||||
table.insert(assistant_content, { type = "text", text = response_content })
|
||||
end
|
||||
end
|
||||
table.insert(assistant_content, {
|
||||
type = "tool_use",
|
||||
name = tool_history.tool_use.name,
|
||||
id = tool_history.tool_use.id,
|
||||
input = vim.json.decode(tool_history.tool_use.input_json),
|
||||
})
|
||||
end
|
||||
table.insert(messages, {
|
||||
role = "assistant",
|
||||
content = assistant_content,
|
||||
})
|
||||
local user_content = {}
|
||||
if tool_history.tool_result ~= nil and tool_history.tool_result.content ~= nil then
|
||||
table.insert(user_content, {
|
||||
type = "tool_result",
|
||||
tool_use_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.content,
|
||||
is_error = tool_history.tool_result.is_error,
|
||||
})
|
||||
end
|
||||
table.insert(messages, {
|
||||
role = "user",
|
||||
content = user_content,
|
||||
})
|
||||
end
|
||||
end
|
||||
local assistant_content = Utils.trim_think_content(entry.original_response or "")
|
||||
if assistant_content ~= "" then table.insert(messages, { role = "assistant", content = assistant_content }) end
|
||||
end
|
||||
return messages
|
||||
end
|
||||
|
||||
return M
|
||||
@@ -6,7 +6,6 @@ local lsp = vim.lsp
|
||||
---@field tokens avante.utils.tokens
|
||||
---@field root avante.utils.root
|
||||
---@field file avante.utils.file
|
||||
---@field history avante.utils.history
|
||||
---@field environment avante.utils.environment
|
||||
---@field lsp avante.utils.lsp
|
||||
local M = {}
|
||||
@@ -415,7 +414,7 @@ function M.debug(...)
|
||||
local caller_source = info.source:match("@(.+)$") or "unknown"
|
||||
local caller_module = caller_source:gsub("^.*/lua/", ""):gsub("%.lua$", ""):gsub("/", ".")
|
||||
|
||||
local timestamp = os.date("%Y-%m-%d %H:%M:%S")
|
||||
local timestamp = M.get_timestamp()
|
||||
local formated_args = {
|
||||
"[" .. timestamp .. "] [AVANTE] [DEBUG] [" .. caller_module .. ":" .. info.currentline .. "]",
|
||||
}
|
||||
@@ -1263,7 +1262,6 @@ function M.get_commands()
|
||||
local builtin_items = {
|
||||
{ description = "Show help message", name = "help" },
|
||||
{ description = "Clear chat history", name = "clear" },
|
||||
{ description = "Reset memory", name = "reset" },
|
||||
{ description = "New chat", name = "new" },
|
||||
{
|
||||
shorthelp = "Ask a question about specific lines",
|
||||
@@ -1281,7 +1279,6 @@ function M.get_commands()
|
||||
if cb then cb(args) end
|
||||
end,
|
||||
clear = function(sidebar, args, cb) sidebar:clear_history(args, cb) end,
|
||||
reset = function(sidebar, args, cb) sidebar:reset_memory(args, cb) end,
|
||||
new = function(sidebar, args, cb) sidebar:new_chat(args, cb) end,
|
||||
lines = function(_, args, cb)
|
||||
if cb then cb(args) end
|
||||
@@ -1310,4 +1307,97 @@ function M.get_commands()
|
||||
return vim.list_extend(builtin_commands, Config.slash_commands)
|
||||
end
|
||||
|
||||
---@param history avante.ChatHistory
|
||||
---@return avante.HistoryMessage[]
|
||||
function M.get_history_messages(history)
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
if history.messages then return history.messages end
|
||||
local messages = {}
|
||||
for _, entry in ipairs(history.entries or {}) do
|
||||
if entry.request and entry.request ~= "" then
|
||||
local message = HistoryMessage:new({
|
||||
role = "user",
|
||||
content = entry.request,
|
||||
}, {
|
||||
timestamp = entry.timestamp,
|
||||
is_user_submission = true,
|
||||
visible = entry.visible,
|
||||
selected_filepaths = entry.selected_filepaths,
|
||||
selected_code = entry.selected_code,
|
||||
})
|
||||
table.insert(messages, message)
|
||||
end
|
||||
if entry.response and entry.response ~= "" then
|
||||
local message = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = entry.response,
|
||||
}, {
|
||||
timestamp = entry.timestamp,
|
||||
visible = entry.visible,
|
||||
})
|
||||
table.insert(messages, message)
|
||||
end
|
||||
end
|
||||
history.messages = messages
|
||||
return messages
|
||||
end
|
||||
|
||||
function M.get_timestamp() return tostring(os.date("%Y-%m-%d %H:%M:%S")) end
|
||||
|
||||
---@param history_messages avante.HistoryMessage[]
|
||||
---@return AvanteLLMMessage[]
|
||||
function M.history_messages_to_messages(history_messages)
|
||||
local messages = {}
|
||||
for _, history_message in ipairs(history_messages) do
|
||||
if history_message.just_for_display then goto continue end
|
||||
table.insert(messages, history_message.message)
|
||||
::continue::
|
||||
end
|
||||
return messages
|
||||
end
|
||||
|
||||
function M.uuid()
|
||||
local template = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"
|
||||
return string.gsub(template, "[xy]", function(c)
|
||||
local v = (c == "x") and math.random(0, 0xf) or math.random(8, 0xb)
|
||||
return string.format("%x", v)
|
||||
end)
|
||||
end
|
||||
|
||||
---@param item AvanteLLMMessageContentItem
|
||||
---@param message avante.HistoryMessage
|
||||
---@return string
|
||||
function M.message_content_item_to_text(item, message)
|
||||
if type(item) == "string" then return item end
|
||||
if type(item) == "table" then
|
||||
if item.type == "text" then return item.text end
|
||||
if item.type == "image" then return "" end
|
||||
if item.type == "tool_use" then
|
||||
local pieces = {}
|
||||
table.insert(pieces, string.format("[%s]: calling", item.name))
|
||||
for _, log in ipairs(message.tool_use_logs or {}) do
|
||||
table.insert(pieces, log)
|
||||
end
|
||||
return table.concat(pieces, "\n")
|
||||
end
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
---@param message avante.HistoryMessage
|
||||
---@return string
|
||||
function M.message_to_text(message)
|
||||
local content = message.message.content
|
||||
if type(content) == "string" then return content end
|
||||
if vim.islist(content) then
|
||||
local pieces = {}
|
||||
for _, item in ipairs(content) do
|
||||
local text = M.message_content_item_to_text(item, message)
|
||||
if text ~= "" then table.insert(pieces, text) end
|
||||
end
|
||||
return table.concat(pieces, "\n")
|
||||
end
|
||||
return ""
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -77,10 +77,24 @@ function StreamingJSONParser:parse(chunk)
|
||||
-- Handle strings specially (they can contain JSON control characters)
|
||||
if self.state.inString then
|
||||
if self.state.escaping then
|
||||
self.state.stringBuffer = self.state.stringBuffer .. char
|
||||
local escapeMap = {
|
||||
['"'] = '"',
|
||||
["\\"] = "\\",
|
||||
["/"] = "/",
|
||||
["b"] = "\b",
|
||||
["f"] = "\f",
|
||||
["n"] = "\n",
|
||||
["r"] = "\r",
|
||||
["t"] = "\t",
|
||||
}
|
||||
local escapedChar = escapeMap[char]
|
||||
if escapedChar then
|
||||
self.state.stringBuffer = self.state.stringBuffer .. escapedChar
|
||||
else
|
||||
self.state.stringBuffer = self.state.stringBuffer .. char
|
||||
end
|
||||
self.state.escaping = false
|
||||
elseif char == "\\" then
|
||||
self.state.stringBuffer = self.state.stringBuffer .. char
|
||||
self.state.escaping = true
|
||||
elseif char == '"' then
|
||||
-- End of string
|
||||
|
||||
@@ -132,17 +132,13 @@ cmd(
|
||||
cmd("Clear", function(opts)
|
||||
local arg = vim.trim(opts.args or "")
|
||||
arg = arg == "" and "history" or arg
|
||||
if arg == "history" or arg == "memory" then
|
||||
if arg == "history" then
|
||||
local sidebar = require("avante").get()
|
||||
if not sidebar then
|
||||
Utils.error("No sidebar found")
|
||||
return
|
||||
end
|
||||
if arg == "history" then
|
||||
sidebar:clear_history()
|
||||
else
|
||||
sidebar:reset_memory()
|
||||
end
|
||||
sidebar:clear_history()
|
||||
elseif arg == "cache" then
|
||||
local P = require("avante.path")
|
||||
local history_path = P.history_path:absolute()
|
||||
|
||||
@@ -29,6 +29,13 @@ describe("StreamingJSONParser", function()
|
||||
assert.equals("value", result.key)
|
||||
end)
|
||||
|
||||
it("should parse breaklines", function()
|
||||
local result, complete = parser:parse('{"key": "value\nv"}')
|
||||
assert.is_true(complete)
|
||||
assert.is_table(result)
|
||||
assert.equals("value\nv", result.key)
|
||||
end)
|
||||
|
||||
it("should parse a complete simple JSON array", function()
|
||||
local result, complete = parser:parse("[1, 2, 3]")
|
||||
assert.is_true(complete)
|
||||
@@ -119,7 +126,7 @@ describe("StreamingJSONParser", function()
|
||||
local result, complete = parser:parse('{"text": "line1\\nline2\\t\\"quoted\\""}')
|
||||
assert.is_true(complete)
|
||||
assert.is_table(result)
|
||||
assert.equals('line1\\nline2\\t\\"quoted\\"', result.text)
|
||||
assert.equals('line1\nline2\t"quoted"', result.text)
|
||||
end)
|
||||
|
||||
it("should handle numbers correctly", function()
|
||||
|
||||
Reference in New Issue
Block a user