feat: cursor planning mode (#1289)

This commit is contained in:
yetone
2025-02-17 18:49:29 +08:00
committed by GitHub
parent b6ae4dfe7f
commit b21d2632d3
10 changed files with 429 additions and 31 deletions

View File

@@ -7,7 +7,7 @@ local curl = require("plenary.curl")
local Utils = require("avante.utils")
local Config = require("avante.config")
local Path = require("avante.path")
local P = require("avante.providers")
local Providers = require("avante.providers")
local LLMTools = require("avante.llm_tools")
---@class avante.LLM
@@ -22,16 +22,16 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
---@param opts GeneratePromptsOptions
---@return AvantePromptOptions
M.generate_prompts = function(opts)
local Provider = opts.provider or P[Config.provider]
local provider = opts.provider or Providers[Config.provider]
local mode = opts.mode or "planning"
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
local _, request_body = P.parse_config(Provider)
local _, request_body = Providers.parse_config(provider)
local max_tokens = request_body.max_tokens or 4096
-- Check if the instructions contains an image path
local image_paths = {}
local instructions = opts.instructions
if opts.instructions:match("image: ") then
if instructions and instructions:match("image: ") then
local lines = vim.split(opts.instructions, "\n")
for i, line in ipairs(lines) do
if line:match("^image: ") then
@@ -49,7 +49,7 @@ M.generate_prompts = function(opts)
local system_info = Utils.get_system_info()
local template_opts = {
use_xml_format = Provider.use_xml_format,
use_xml_format = provider.use_xml_format,
ask = opts.ask, -- TODO: add mode without ask instruction
code_lang = opts.code_lang,
selected_files = opts.selected_files,
@@ -57,6 +57,7 @@ M.generate_prompts = function(opts)
project_context = opts.project_context,
diagnostics = opts.diagnostics,
system_info = system_info,
model_name = provider.model or "unknown",
}
local system_prompt = Path.prompts.render_mode(mode, template_opts)
@@ -74,15 +75,17 @@ M.generate_prompts = function(opts)
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
end
if #opts.selected_files > 0 or opts.selected_code ~= nil then
if (opts.selected_files and #opts.selected_files > 0 or false) or opts.selected_code ~= nil then
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
end
if opts.use_xml_format then
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
else
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
if instructions then
if opts.use_xml_format then
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
else
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
end
end
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
@@ -110,6 +113,22 @@ M.generate_prompts = function(opts)
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
end
if opts.mode == "cursor-applying" then
local user_prompt = [[
Merge all changes from the <update> snippet into the <code> below.
- Preserve the code's structure, order, comments, and indentation exactly.
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
]]
user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code)
for _, snippet in ipairs(opts.update_snippets) do
user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet)
end
user_prompt = user_prompt .. "Provide the complete updated code."
table.insert(messages, { role = "user", content = user_prompt })
end
---@type AvantePromptOptions
return {
system_prompt = system_prompt,
@@ -133,7 +152,7 @@ end
---@param opts StreamOptions
M._stream = function(opts)
local Provider = opts.provider or P[Config.provider]
local provider = opts.provider or Providers[Config.provider]
local prompt_opts = M.generate_prompts(opts)
@@ -166,7 +185,7 @@ M._stream = function(opts)
}
---@type AvanteCurlOutput
local spec = Provider.parse_curl_args(Provider, prompt_opts)
local spec = provider.parse_curl_args(provider, prompt_opts)
local resp_ctx = {}
@@ -178,11 +197,11 @@ M._stream = function(opts)
return
end
local data_match = line:match("^data: (.+)$")
if data_match then Provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
if data_match then provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
end
local function parse_response_without_stream(data)
Provider.parse_response_without_stream(data, current_event_state, handler_opts)
provider.parse_response_without_stream(data, current_event_state, handler_opts)
end
local completed = false
@@ -214,17 +233,17 @@ M._stream = function(opts)
end
if not data then return end
vim.schedule(function()
if Config[Config.provider] == nil and Provider.parse_stream_data ~= nil then
if Provider.parse_response ~= nil then
if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
if provider.parse_response ~= nil then
Utils.warn(
"parse_stream_data and parse_response are mutually exclusive, and thus parse_response will be ignored. Make sure that you handle the incoming data correctly.",
{ once = true }
)
end
Provider.parse_stream_data(data, handler_opts)
provider.parse_stream_data(data, handler_opts)
else
if Provider.parse_stream_data ~= nil then
Provider.parse_stream_data(data, handler_opts)
if provider.parse_stream_data ~= nil then
provider.parse_stream_data(data, handler_opts)
else
parse_stream_data(data)
end
@@ -259,8 +278,8 @@ M._stream = function(opts)
active_job = nil
cleanup()
if result.status >= 400 then
if Provider.on_error then
Provider.on_error(result)
if provider.on_error then
provider.on_error(result)
else
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
end
@@ -388,7 +407,7 @@ M._dual_boost_stream = function(opts, Provider1, Provider2)
if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
end
---@alias LlmMode "planning" | "editing" | "suggesting"
---@alias LlmMode "planning" | "editing" | "suggesting" | "cursor-planning" | "cursor-applying"
---
---@class SelectedFiles
---@field path string
@@ -408,11 +427,13 @@ end
---
---@class GeneratePromptsOptions: TemplateOptions
---@field ask boolean
---@field instructions string
---@field instructions? string
---@field mode LlmMode
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
---@field tools? AvanteLLMTool[]
---@field tool_histories? AvanteLLMToolHistory[]
---@field original_code? string
---@field update_snippets? string[]
---
---@class AvanteLLMToolHistory
---@field tool_result? AvanteLLMToolResult
@@ -450,7 +471,11 @@ M.stream = function(opts)
end)
end
if Config.dual_boost.enabled and opts.mode == "planning" then
M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
M._dual_boost_stream(
opts,
Providers[Config.dual_boost.first_provider],
Providers[Config.dual_boost.second_provider]
)
else
M._stream(opts)
end