diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index d16883a..2594a6d 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -67,43 +67,31 @@ function M.generate_prompts(opts) if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then local project_context = Path.prompts.render_file("_project.avanterules", template_opts) - if project_context ~= "" then - table.insert(messages, { role = "user", content = { { type = "text", text = project_context } } }) - end + if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end end if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts) - if diagnostics ~= "" then - table.insert(messages, { role = "user", content = { { type = "text", text = diagnostics } } }) - end + if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end end if (opts.selected_files and #opts.selected_files > 0 or false) or opts.selected_code ~= nil then local code_context = Path.prompts.render_file("_context.avanterules", template_opts) - if code_context ~= "" then - table.insert(messages, { role = "user", content = { { type = "text", text = code_context } } }) - end + if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end end if instructions then if opts.use_xml_format then - table.insert(messages, { - role = "user", - content = { { type = "text", text = string.format("%s", instructions) } }, - }) + table.insert(messages, { role = "user", content = string.format("%s", instructions) }) else - table.insert( - messages, - { role = "user", content = { { type = "text", text = string.format("QUESTION:\n%s", instructions) } } } - ) + table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) }) end end local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt) for _, message in ipairs(messages) do - remaining_tokens = remaining_tokens - Utils.tokens.calculate_message_content_tokens(message.content) + remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content) end if opts.history_messages then @@ -112,7 +100,7 @@ function M.generate_prompts(opts) local history_messages = {} for i = #opts.history_messages, 1, -1 do local message = opts.history_messages[i] - local tokens = Utils.tokens.calculate_message_content_tokens(message.content) + local tokens = Utils.tokens.calculate_tokens(message.content) remaining_tokens = remaining_tokens - tokens if remaining_tokens > 0 then table.insert(history_messages, message) @@ -138,7 +126,7 @@ Merge all changes from the snippet into the below. user_prompt = user_prompt .. string.format("\n%s\n\n", snippet) end user_prompt = user_prompt .. "Provide the complete updated code." - table.insert(messages, { role = "user", content = { { type = "text", text = user_prompt } } }) + table.insert(messages, { role = "user", content = user_prompt }) end ---@type AvantePromptOptions @@ -157,7 +145,7 @@ function M.calculate_tokens(opts) local prompt_opts = M.generate_prompts(opts) local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt) for _, message in ipairs(prompt_opts.messages) do - tokens = tokens + Utils.tokens.calculate_message_content_tokens(message.content) + tokens = tokens + Utils.tokens.calculate_tokens(message.content) end return tokens end diff --git a/lua/avante/path.lua b/lua/avante/path.lua index 30cc5cd..e404a45 100644 --- a/lua/avante/path.lua +++ b/lua/avante/path.lua @@ -5,6 +5,18 @@ local Path = require("plenary.path") local Scan = require("plenary.scandir") local Config = require("avante.config") +---@class avante.ChatHistoryEntry +---@field timestamp string +---@field provider string +---@field model string +---@field request string +---@field response string +---@field original_response string +---@field selected_file {filepath: string}? +---@field selected_code {filetype: string, content: string}? +---@field reset_memory boolean? +---@field selected_filepaths string[] | nil + ---@class avante.Path ---@field history_path Path ---@field cache_path Path diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua index d9bf2ee..d018780 100644 --- a/lua/avante/providers/bedrock/claude.lua +++ b/lua/avante/providers/bedrock/claude.lua @@ -16,7 +16,74 @@ M.role_map = { assistant = "assistant", } -M.parse_messages = Claude.parse_messages +function M.parse_messages(opts) + ---@type AvanteBedrockClaudeMessage[] + local messages = {} + + for _, message in ipairs(opts.messages) do + table.insert(messages, { + role = M.role_map[message.role], + content = { + { + type = "text", + text = message.content, + }, + }, + }) + end + + if opts.tool_histories then + for _, tool_history in ipairs(opts.tool_histories) do + if tool_history.tool_use then + local msg = { + role = "assistant", + content = {}, + } + if tool_history.tool_use.thinking_contents then + for _, thinking_content in ipairs(tool_history.tool_use.thinking_contents) do + msg.content[#msg.content + 1] = { + type = "thinking", + thinking = thinking_content.content, + signature = thinking_content.signature, + } + end + end + if tool_history.tool_use.response_contents then + for _, response_content in ipairs(tool_history.tool_use.response_contents) do + msg.content[#msg.content + 1] = { + type = "text", + text = response_content, + } + end + end + msg.content[#msg.content + 1] = { + type = "tool_use", + id = tool_history.tool_use.id, + name = tool_history.tool_use.name, + input = vim.json.decode(tool_history.tool_use.input_json), + } + messages[#messages + 1] = msg + end + + if tool_history.tool_result then + messages[#messages + 1] = { + role = "user", + content = { + { + type = "tool_result", + tool_use_id = tool_history.tool_result.tool_use_id, + content = tool_history.tool_result.content, + is_error = tool_history.tool_result.is_error, + }, + }, + } + end + end + end + + return messages +end + M.parse_response = Claude.parse_response ---@param prompt_opts AvantePromptOptions diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index a16622a..5c0fcbe 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -43,10 +43,7 @@ function M.parse_messages(opts) ---@type {idx: integer, length: integer}[] local messages_with_length = {} for idx, message in ipairs(opts.messages) do - table.insert( - messages_with_length, - { idx = idx, length = Utils.tokens.calculate_message_content_tokens(message.content) } - ) + table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) }) end table.sort(messages_with_length, function(a, b) return a.length > b.length end) @@ -58,46 +55,15 @@ function M.parse_messages(opts) end for idx, message in ipairs(opts.messages) do - local content = {} - if type(message.content) == "string" then - table.insert(content, { - type = "text", - text = message.content, - cache_control = top_three[idx] and { type = "ephemeral" } or nil, - }) - else - local message_content = message.content - ---@cast message_content AvanteLLMMessageContentItem[] - for _, item in ipairs(message_content) do - if type(item) == "string" then - table.insert(content, { - type = "text", - text = item, - cache_control = top_three[idx] and { type = "ephemeral" } or nil, - }) - elseif item.type == "text" then - table.insert(content, { - type = "text", - text = item.text, - cache_control = top_three[idx] and { type = "ephemeral" } or nil, - }) - elseif item.type == "thinking" then - table.insert(content, { - type = "thinking", - thinking = item.thinking, - signature = item.signature, - }) - elseif item.type == "redacted_thinking" then - table.insert(content, { - type = "redacted_thinking", - data = item.data, - }) - end - end - end table.insert(messages, { role = M.role_map[message.role], - content = content, + content = { + { + type = "text", + text = message.content, + cache_control = top_three[idx] and { type = "ephemeral" } or nil, + }, + }, }) end @@ -123,20 +89,12 @@ function M.parse_messages(opts) role = "assistant", content = {}, } - if tool_history.tool_use.thinking_blocks then - for _, thinking_block in ipairs(tool_history.tool_use.thinking_blocks) do + if tool_history.tool_use.thinking_contents then + for _, thinking_content in ipairs(tool_history.tool_use.thinking_contents) do msg.content[#msg.content + 1] = { type = "thinking", - thinking = thinking_block.thinking, - signature = thinking_block.signature, - } - end - end - if tool_history.tool_use.redacted_thinking_blocks then - for _, thinking_block in ipairs(tool_history.tool_use.redacted_thinking_blocks) do - msg.content[#msg.content + 1] = { - type = "redacted_thinking", - data = thinking_block.data, + thinking = thinking_content.content, + signature = thinking_content.signature, } end end @@ -177,8 +135,6 @@ function M.parse_messages(opts) end function M.parse_response(ctx, data_stream, event_state, opts) - if ctx.resp_filename == nil then ctx.resp_filename = vim.fn.tempname() .. ".txt" end - vim.fn.writefile({ data_stream }, ctx.resp_filename, "a") if event_state == nil then if data_stream:match('"message_start"') then event_state = "message_start" @@ -252,33 +208,24 @@ function M.parse_response(ctx, data_stream, event_state, opts) :filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end) :map(function(content_block_) return content_block_.text end) :totable() - local thinking_blocks = vim + local thinking_contents = vim .iter(ctx.content_blocks) :filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end) :map( function(content_block_) - return { thinking = content_block_.thinking, signature = content_block_.signature } + return { content = content_block_.thinking, signature = content_block_.signature } end ) :totable() - local redacted_thinking_blocks = vim - .iter(ctx.content_blocks) - :filter( - function(content_block_) return content_block_.stoppped and content_block_.type == "redacted_thinking" end - ) - :map(function(content_block_) return { data = content_block_.data } end) - :totable() return { name = content_block.name, id = content_block.id, input_json = content_block.input_json, response_contents = response_contents, - thinking_blocks = thinking_blocks, - redacted_thinking_blocks = redacted_thinking_blocks, + thinking_contents = thinking_contents, } end) :totable() - Utils.debug("resp filename", ctx.resp_filename) opts.on_stop({ reason = "tool_use", usage = jsn.usage, diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index 5b679e5..a891e4e 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -33,15 +33,9 @@ function M.parse_messages(opts) end end prev_role = role - local parts = {} - for _, item in ipairs(message.content) do - if type(item) == "string" then - table.insert(parts, { text = item }) - elseif item.type == "text" then - table.insert(parts, { text = item.text }) - end - end - table.insert(contents, { role = M.role_map[role] or role, parts = parts }) + table.insert(contents, { role = M.role_map[role] or role, parts = { + { text = message.content }, + } }) end) if Clipboard.support_paste_image() and opts.image_paths then diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index c6f3fd9..9eff43a 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -71,30 +71,18 @@ function M.parse_messages(opts) -- NOTE: Handle the case where the selected model is the `o1` model -- "o1" models are "smart" enough to understand user prompt as a system prompt in this context if M.is_o_series_model(base.model) then - table.insert(messages, { role = "user", content = { { type = "text", text = opts.system_prompt } } }) + table.insert(messages, { role = "user", content = opts.system_prompt }) else table.insert(messages, { role = "system", content = opts.system_prompt }) end - vim.iter(opts.messages):each(function(msg) - if type(msg.content) == "string" then - table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) - else - local content = {} - for _, item in ipairs(msg.content) do - if type(item) == "string" then - table.insert(content, { type = "text", text = item }) - elseif item.type == "text" then - table.insert(content, { type = "text", text = item.text }) - end - end - table.insert(messages, { role = M.role_map[msg.role], content = content }) - end - end) + vim + .iter(opts.messages) + :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end) if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then local message_content = messages[#messages].content - if type(message_content) == "string" then message_content = { { type = "text", text = message_content } } end + if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end for _, image_path in ipairs(opts.image_paths) do table.insert(message_content, { type = "image_url", @@ -115,7 +103,7 @@ function M.parse_messages(opts) if role == M.role_map["user"] then table.insert(final_messages, { role = M.role_map["assistant"], content = "Ok, I understand." }) else - table.insert(final_messages, { role = M.role_map["user"], content = { { type = "text", text = "Ok" } } }) + table.insert(final_messages, { role = M.role_map["user"], content = "Ok" }) end end prev_role = role diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 5de6e0c..c373f62 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -392,7 +392,6 @@ local function transform_result_content(selected_files, result_content, prev_fil elseif line_content == "" then is_thinking = true last_think_tag_start_line = i - last_think_tag_end_line = 0 elseif line_content == "" then is_thinking = false last_think_tag_end_line = i @@ -2357,7 +2356,6 @@ function Sidebar:create_input_container(opts) end end - ---@type AvanteLLMMessage[] local history_messages = {} for i = #chat_history, 1, -1 do local entry = chat_history[i] @@ -2370,9 +2368,11 @@ function Sidebar:create_input_container(opts) then break end - local assistant_content = {} - table.insert(assistant_content, { type = "text", text = Utils.trim_think_content(entry.original_response) }) - table.insert(history_messages, 1, { role = "assistant", content = assistant_content }) + table.insert( + history_messages, + 1, + { role = "assistant", content = Utils.trim_think_content(entry.original_response) } + ) local user_content = "" if entry.selected_file ~= nil then user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n" @@ -2386,18 +2386,9 @@ function Sidebar:create_input_container(opts) .. "\n```\n\n" end user_content = user_content .. "USER PROMPT:\n\n" .. entry.request - table.insert(history_messages, 1, { - role = "user", - content = { - { - type = "text", - text = user_content, - }, - }, - }) + table.insert(history_messages, 1, { role = "user", content = user_content }) end - ---@type AvanteGeneratePromptsOptions return { ask = opts.ask or true, project_context = vim.json.encode(project_context), diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index 5b547f8..aa3f2d4 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -78,10 +78,7 @@ function Suggestion:suggest() local history_messages = { { role = "user", - content = { - { - type = "text", - text = [[ + content = [[ a.py L1: def fib @@ -90,9 +87,7 @@ L3: if __name__ == "__main__": L4: # just pass L5: pass -]], - }, - }, + ]], }, { role = "assistant", @@ -100,12 +95,7 @@ L5: pass }, { role = "user", - content = { - { - type = "text", - text = '{"insertSpaces":true,"tabSize":4,"indentSize":4,"position":{"row":1,"col":7}}', - }, - }, + content = '{"insertSpaces":true,"tabSize":4,"indentSize":4,"position":{"row":1,"col":7}}', }, { role = "assistant", diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 7ed0cd1..fc214b4 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -76,13 +76,9 @@ vim.g.avante_login = vim.g.avante_login ---@field on_chunk AvanteLLMChunkCallback ---@field on_stop AvanteLLMStopCallback --- ----@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } ---- ----@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string ---- ---@class AvanteLLMMessage ---@field role "user" | "assistant" ----@field content AvanteLLMMessageContent +---@field content string --- ---@class AvanteLLMToolResult ---@field tool_name string @@ -231,19 +227,11 @@ vim.g.avante_login = vim.g.avante_login ---@field id string ---@field input_json string ---@field response_contents? string[] ----@field thinking_blocks? AvanteLLMThinkingBlock[] ----@field redacted_thinking_blocks? AvanteLLMRedactedThinkingBlock[] +---@field thinking_contents? { content: string, signature: string }[] --- ---@class AvanteLLMStartCallbackOptions ---@field usage? AvanteLLMUsage --- ----@class AvanteLLMThinkingBlock ----@field thinking string ----@field signature string ---- ----@class AvanteLLMRedactedThinkingBlock ----@field data string ---- ---@class AvanteLLMStopCallbackOptions ---@field reason "complete" | "tool_use" | "error" | "rate_limit" ---@field error? string | table @@ -355,15 +343,3 @@ vim.g.avante_login = vim.g.avante_login ---@field description string ---@field type string ---@field optional? boolean - ----@class avante.ChatHistoryEntry ----@field timestamp string ----@field provider string ----@field model string ----@field request string ----@field response string ----@field original_response string ----@field selected_file {filepath: string}? ----@field selected_code {filetype: string, content: string}? ----@field reset_memory boolean? ----@field selected_filepaths string[] | nil diff --git a/lua/avante/utils/tokens.lua b/lua/avante/utils/tokens.lua index 1193f99..4f1382d 100644 --- a/lua/avante/utils/tokens.lua +++ b/lua/avante/utils/tokens.lua @@ -9,21 +9,6 @@ local cost_per_token = { davinci = 0.000002, } ----@param content AvanteLLMMessageContent ----@return integer -function Tokens.calculate_message_content_tokens(content) - if type(content) == "string" then return Tokens.calculate_tokens(content) end - local tokens = 0 - for _, item in ipairs(content) do - if type(item) == "string" then - tokens = tokens + Tokens.calculate_tokens(item) - elseif item.type == "text" then - tokens = tokens + Tokens.calculate_tokens(item.text) - end - end - return tokens -end - --- Calculate the number of tokens in a given text. ---@param text string The text to calculate the number of tokens for. ---@return integer The number of tokens in the given text.