feat: cursor planning mode (#1289)
This commit is contained in:
@@ -7,7 +7,7 @@ local curl = require("plenary.curl")
|
||||
local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Path = require("avante.path")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local LLMTools = require("avante.llm_tools")
|
||||
|
||||
---@class avante.LLM
|
||||
@@ -22,16 +22,16 @@ local group = api.nvim_create_augroup("avante_llm", { clear = true })
|
||||
---@param opts GeneratePromptsOptions
|
||||
---@return AvantePromptOptions
|
||||
M.generate_prompts = function(opts)
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
local mode = opts.mode or "planning"
|
||||
---@type AvanteProviderFunctor | AvanteBedrockProviderFunctor
|
||||
local _, request_body = P.parse_config(Provider)
|
||||
local _, request_body = Providers.parse_config(provider)
|
||||
local max_tokens = request_body.max_tokens or 4096
|
||||
|
||||
-- Check if the instructions contains an image path
|
||||
local image_paths = {}
|
||||
local instructions = opts.instructions
|
||||
if opts.instructions:match("image: ") then
|
||||
if instructions and instructions:match("image: ") then
|
||||
local lines = vim.split(opts.instructions, "\n")
|
||||
for i, line in ipairs(lines) do
|
||||
if line:match("^image: ") then
|
||||
@@ -49,7 +49,7 @@ M.generate_prompts = function(opts)
|
||||
local system_info = Utils.get_system_info()
|
||||
|
||||
local template_opts = {
|
||||
use_xml_format = Provider.use_xml_format,
|
||||
use_xml_format = provider.use_xml_format,
|
||||
ask = opts.ask, -- TODO: add mode without ask instruction
|
||||
code_lang = opts.code_lang,
|
||||
selected_files = opts.selected_files,
|
||||
@@ -57,6 +57,7 @@ M.generate_prompts = function(opts)
|
||||
project_context = opts.project_context,
|
||||
diagnostics = opts.diagnostics,
|
||||
system_info = system_info,
|
||||
model_name = provider.model or "unknown",
|
||||
}
|
||||
|
||||
local system_prompt = Path.prompts.render_mode(mode, template_opts)
|
||||
@@ -74,15 +75,17 @@ M.generate_prompts = function(opts)
|
||||
if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
|
||||
end
|
||||
|
||||
if #opts.selected_files > 0 or opts.selected_code ~= nil then
|
||||
if (opts.selected_files and #opts.selected_files > 0 or false) or opts.selected_code ~= nil then
|
||||
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
|
||||
if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
|
||||
end
|
||||
|
||||
if opts.use_xml_format then
|
||||
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||
else
|
||||
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
|
||||
if instructions then
|
||||
if opts.use_xml_format then
|
||||
table.insert(messages, { role = "user", content = string.format("<question>%s</question>", instructions) })
|
||||
else
|
||||
table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
|
||||
end
|
||||
end
|
||||
|
||||
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
|
||||
@@ -110,6 +113,22 @@ M.generate_prompts = function(opts)
|
||||
if #messages > 0 and messages[1].role == "assistant" then table.remove(messages, 1) end
|
||||
end
|
||||
|
||||
if opts.mode == "cursor-applying" then
|
||||
local user_prompt = [[
|
||||
Merge all changes from the <update> snippet into the <code> below.
|
||||
- Preserve the code's structure, order, comments, and indentation exactly.
|
||||
- Output only the updated code, enclosed within <updated-code> and </updated-code> tags.
|
||||
- Do not include any additional text, explanations, placeholders, ellipses, or code fences.
|
||||
|
||||
]]
|
||||
user_prompt = user_prompt .. string.format("<code>\n%s\n</code>\n", opts.original_code)
|
||||
for _, snippet in ipairs(opts.update_snippets) do
|
||||
user_prompt = user_prompt .. string.format("<update>\n%s\n</update>\n", snippet)
|
||||
end
|
||||
user_prompt = user_prompt .. "Provide the complete updated code."
|
||||
table.insert(messages, { role = "user", content = user_prompt })
|
||||
end
|
||||
|
||||
---@type AvantePromptOptions
|
||||
return {
|
||||
system_prompt = system_prompt,
|
||||
@@ -133,7 +152,7 @@ end
|
||||
|
||||
---@param opts StreamOptions
|
||||
M._stream = function(opts)
|
||||
local Provider = opts.provider or P[Config.provider]
|
||||
local provider = opts.provider or Providers[Config.provider]
|
||||
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
|
||||
@@ -166,7 +185,7 @@ M._stream = function(opts)
|
||||
}
|
||||
|
||||
---@type AvanteCurlOutput
|
||||
local spec = Provider.parse_curl_args(Provider, prompt_opts)
|
||||
local spec = provider.parse_curl_args(provider, prompt_opts)
|
||||
|
||||
local resp_ctx = {}
|
||||
|
||||
@@ -178,11 +197,11 @@ M._stream = function(opts)
|
||||
return
|
||||
end
|
||||
local data_match = line:match("^data: (.+)$")
|
||||
if data_match then Provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
|
||||
if data_match then provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
|
||||
end
|
||||
|
||||
local function parse_response_without_stream(data)
|
||||
Provider.parse_response_without_stream(data, current_event_state, handler_opts)
|
||||
provider.parse_response_without_stream(data, current_event_state, handler_opts)
|
||||
end
|
||||
|
||||
local completed = false
|
||||
@@ -214,17 +233,17 @@ M._stream = function(opts)
|
||||
end
|
||||
if not data then return end
|
||||
vim.schedule(function()
|
||||
if Config[Config.provider] == nil and Provider.parse_stream_data ~= nil then
|
||||
if Provider.parse_response ~= nil then
|
||||
if Config[Config.provider] == nil and provider.parse_stream_data ~= nil then
|
||||
if provider.parse_response ~= nil then
|
||||
Utils.warn(
|
||||
"parse_stream_data and parse_response are mutually exclusive, and thus parse_response will be ignored. Make sure that you handle the incoming data correctly.",
|
||||
{ once = true }
|
||||
)
|
||||
end
|
||||
Provider.parse_stream_data(data, handler_opts)
|
||||
provider.parse_stream_data(data, handler_opts)
|
||||
else
|
||||
if Provider.parse_stream_data ~= nil then
|
||||
Provider.parse_stream_data(data, handler_opts)
|
||||
if provider.parse_stream_data ~= nil then
|
||||
provider.parse_stream_data(data, handler_opts)
|
||||
else
|
||||
parse_stream_data(data)
|
||||
end
|
||||
@@ -259,8 +278,8 @@ M._stream = function(opts)
|
||||
active_job = nil
|
||||
cleanup()
|
||||
if result.status >= 400 then
|
||||
if Provider.on_error then
|
||||
Provider.on_error(result)
|
||||
if provider.on_error then
|
||||
provider.on_error(result)
|
||||
else
|
||||
Utils.error("API request failed with status " .. result.status, { once = true, title = "Avante" })
|
||||
end
|
||||
@@ -388,7 +407,7 @@ M._dual_boost_stream = function(opts, Provider1, Provider2)
|
||||
if not success then Utils.error("Failed to start dual_boost streams: " .. tostring(err)) end
|
||||
end
|
||||
|
||||
---@alias LlmMode "planning" | "editing" | "suggesting"
|
||||
---@alias LlmMode "planning" | "editing" | "suggesting" | "cursor-planning" | "cursor-applying"
|
||||
---
|
||||
---@class SelectedFiles
|
||||
---@field path string
|
||||
@@ -408,11 +427,13 @@ end
|
||||
---
|
||||
---@class GeneratePromptsOptions: TemplateOptions
|
||||
---@field ask boolean
|
||||
---@field instructions string
|
||||
---@field instructions? string
|
||||
---@field mode LlmMode
|
||||
---@field provider AvanteProviderFunctor | AvanteBedrockProviderFunctor | nil
|
||||
---@field tools? AvanteLLMTool[]
|
||||
---@field tool_histories? AvanteLLMToolHistory[]
|
||||
---@field original_code? string
|
||||
---@field update_snippets? string[]
|
||||
---
|
||||
---@class AvanteLLMToolHistory
|
||||
---@field tool_result? AvanteLLMToolResult
|
||||
@@ -450,7 +471,11 @@ M.stream = function(opts)
|
||||
end)
|
||||
end
|
||||
if Config.dual_boost.enabled and opts.mode == "planning" then
|
||||
M._dual_boost_stream(opts, P[Config.dual_boost.first_provider], P[Config.dual_boost.second_provider])
|
||||
M._dual_boost_stream(
|
||||
opts,
|
||||
Providers[Config.dual_boost.first_provider],
|
||||
Providers[Config.dual_boost.second_provider]
|
||||
)
|
||||
else
|
||||
M._stream(opts)
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user