diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua
index 2594a6d..d16883a 100644
--- a/lua/avante/llm.lua
+++ b/lua/avante/llm.lua
@@ -67,31 +67,43 @@ function M.generate_prompts(opts)
if opts.project_context ~= nil and opts.project_context ~= "" and opts.project_context ~= "null" then
local project_context = Path.prompts.render_file("_project.avanterules", template_opts)
- if project_context ~= "" then table.insert(messages, { role = "user", content = project_context }) end
+ if project_context ~= "" then
+ table.insert(messages, { role = "user", content = { { type = "text", text = project_context } } })
+ end
end
if opts.diagnostics ~= nil and opts.diagnostics ~= "" and opts.diagnostics ~= "null" then
local diagnostics = Path.prompts.render_file("_diagnostics.avanterules", template_opts)
- if diagnostics ~= "" then table.insert(messages, { role = "user", content = diagnostics }) end
+ if diagnostics ~= "" then
+ table.insert(messages, { role = "user", content = { { type = "text", text = diagnostics } } })
+ end
end
if (opts.selected_files and #opts.selected_files > 0 or false) or opts.selected_code ~= nil then
local code_context = Path.prompts.render_file("_context.avanterules", template_opts)
- if code_context ~= "" then table.insert(messages, { role = "user", content = code_context }) end
+ if code_context ~= "" then
+ table.insert(messages, { role = "user", content = { { type = "text", text = code_context } } })
+ end
end
if instructions then
if opts.use_xml_format then
- table.insert(messages, { role = "user", content = string.format("%s", instructions) })
+ table.insert(messages, {
+ role = "user",
+ content = { { type = "text", text = string.format("%s", instructions) } },
+ })
else
- table.insert(messages, { role = "user", content = string.format("QUESTION:\n%s", instructions) })
+ table.insert(
+ messages,
+ { role = "user", content = { { type = "text", text = string.format("QUESTION:\n%s", instructions) } } }
+ )
end
end
local remaining_tokens = max_tokens - Utils.tokens.calculate_tokens(system_prompt)
for _, message in ipairs(messages) do
- remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
+ remaining_tokens = remaining_tokens - Utils.tokens.calculate_message_content_tokens(message.content)
end
if opts.history_messages then
@@ -100,7 +112,7 @@ function M.generate_prompts(opts)
local history_messages = {}
for i = #opts.history_messages, 1, -1 do
local message = opts.history_messages[i]
- local tokens = Utils.tokens.calculate_tokens(message.content)
+ local tokens = Utils.tokens.calculate_message_content_tokens(message.content)
remaining_tokens = remaining_tokens - tokens
if remaining_tokens > 0 then
table.insert(history_messages, message)
@@ -126,7 +138,7 @@ Merge all changes from the snippet into the below.
user_prompt = user_prompt .. string.format("\n%s\n\n", snippet)
end
user_prompt = user_prompt .. "Provide the complete updated code."
- table.insert(messages, { role = "user", content = user_prompt })
+ table.insert(messages, { role = "user", content = { { type = "text", text = user_prompt } } })
end
---@type AvantePromptOptions
@@ -145,7 +157,7 @@ function M.calculate_tokens(opts)
local prompt_opts = M.generate_prompts(opts)
local tokens = Utils.tokens.calculate_tokens(prompt_opts.system_prompt)
for _, message in ipairs(prompt_opts.messages) do
- tokens = tokens + Utils.tokens.calculate_tokens(message.content)
+ tokens = tokens + Utils.tokens.calculate_message_content_tokens(message.content)
end
return tokens
end
diff --git a/lua/avante/path.lua b/lua/avante/path.lua
index e404a45..30cc5cd 100644
--- a/lua/avante/path.lua
+++ b/lua/avante/path.lua
@@ -5,18 +5,6 @@ local Path = require("plenary.path")
local Scan = require("plenary.scandir")
local Config = require("avante.config")
----@class avante.ChatHistoryEntry
----@field timestamp string
----@field provider string
----@field model string
----@field request string
----@field response string
----@field original_response string
----@field selected_file {filepath: string}?
----@field selected_code {filetype: string, content: string}?
----@field reset_memory boolean?
----@field selected_filepaths string[] | nil
-
---@class avante.Path
---@field history_path Path
---@field cache_path Path
diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua
index d018780..d9bf2ee 100644
--- a/lua/avante/providers/bedrock/claude.lua
+++ b/lua/avante/providers/bedrock/claude.lua
@@ -16,74 +16,7 @@ M.role_map = {
assistant = "assistant",
}
-function M.parse_messages(opts)
- ---@type AvanteBedrockClaudeMessage[]
- local messages = {}
-
- for _, message in ipairs(opts.messages) do
- table.insert(messages, {
- role = M.role_map[message.role],
- content = {
- {
- type = "text",
- text = message.content,
- },
- },
- })
- end
-
- if opts.tool_histories then
- for _, tool_history in ipairs(opts.tool_histories) do
- if tool_history.tool_use then
- local msg = {
- role = "assistant",
- content = {},
- }
- if tool_history.tool_use.thinking_contents then
- for _, thinking_content in ipairs(tool_history.tool_use.thinking_contents) do
- msg.content[#msg.content + 1] = {
- type = "thinking",
- thinking = thinking_content.content,
- signature = thinking_content.signature,
- }
- end
- end
- if tool_history.tool_use.response_contents then
- for _, response_content in ipairs(tool_history.tool_use.response_contents) do
- msg.content[#msg.content + 1] = {
- type = "text",
- text = response_content,
- }
- end
- end
- msg.content[#msg.content + 1] = {
- type = "tool_use",
- id = tool_history.tool_use.id,
- name = tool_history.tool_use.name,
- input = vim.json.decode(tool_history.tool_use.input_json),
- }
- messages[#messages + 1] = msg
- end
-
- if tool_history.tool_result then
- messages[#messages + 1] = {
- role = "user",
- content = {
- {
- type = "tool_result",
- tool_use_id = tool_history.tool_result.tool_use_id,
- content = tool_history.tool_result.content,
- is_error = tool_history.tool_result.is_error,
- },
- },
- }
- end
- end
- end
-
- return messages
-end
-
+M.parse_messages = Claude.parse_messages
M.parse_response = Claude.parse_response
---@param prompt_opts AvantePromptOptions
diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua
index 5c0fcbe..a16622a 100644
--- a/lua/avante/providers/claude.lua
+++ b/lua/avante/providers/claude.lua
@@ -43,7 +43,10 @@ function M.parse_messages(opts)
---@type {idx: integer, length: integer}[]
local messages_with_length = {}
for idx, message in ipairs(opts.messages) do
- table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) })
+ table.insert(
+ messages_with_length,
+ { idx = idx, length = Utils.tokens.calculate_message_content_tokens(message.content) }
+ )
end
table.sort(messages_with_length, function(a, b) return a.length > b.length end)
@@ -55,15 +58,46 @@ function M.parse_messages(opts)
end
for idx, message in ipairs(opts.messages) do
+ local content = {}
+ if type(message.content) == "string" then
+ table.insert(content, {
+ type = "text",
+ text = message.content,
+ cache_control = top_three[idx] and { type = "ephemeral" } or nil,
+ })
+ else
+ local message_content = message.content
+ ---@cast message_content AvanteLLMMessageContentItem[]
+ for _, item in ipairs(message_content) do
+ if type(item) == "string" then
+ table.insert(content, {
+ type = "text",
+ text = item,
+ cache_control = top_three[idx] and { type = "ephemeral" } or nil,
+ })
+ elseif item.type == "text" then
+ table.insert(content, {
+ type = "text",
+ text = item.text,
+ cache_control = top_three[idx] and { type = "ephemeral" } or nil,
+ })
+ elseif item.type == "thinking" then
+ table.insert(content, {
+ type = "thinking",
+ thinking = item.thinking,
+ signature = item.signature,
+ })
+ elseif item.type == "redacted_thinking" then
+ table.insert(content, {
+ type = "redacted_thinking",
+ data = item.data,
+ })
+ end
+ end
+ end
table.insert(messages, {
role = M.role_map[message.role],
- content = {
- {
- type = "text",
- text = message.content,
- cache_control = top_three[idx] and { type = "ephemeral" } or nil,
- },
- },
+ content = content,
})
end
@@ -89,12 +123,20 @@ function M.parse_messages(opts)
role = "assistant",
content = {},
}
- if tool_history.tool_use.thinking_contents then
- for _, thinking_content in ipairs(tool_history.tool_use.thinking_contents) do
+ if tool_history.tool_use.thinking_blocks then
+ for _, thinking_block in ipairs(tool_history.tool_use.thinking_blocks) do
msg.content[#msg.content + 1] = {
type = "thinking",
- thinking = thinking_content.content,
- signature = thinking_content.signature,
+ thinking = thinking_block.thinking,
+ signature = thinking_block.signature,
+ }
+ end
+ end
+ if tool_history.tool_use.redacted_thinking_blocks then
+ for _, thinking_block in ipairs(tool_history.tool_use.redacted_thinking_blocks) do
+ msg.content[#msg.content + 1] = {
+ type = "redacted_thinking",
+ data = thinking_block.data,
}
end
end
@@ -208,24 +250,32 @@ function M.parse_response(ctx, data_stream, event_state, opts)
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end)
:map(function(content_block_) return content_block_.text end)
:totable()
- local thinking_contents = vim
+ local thinking_blocks = vim
.iter(ctx.content_blocks)
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end)
:map(
function(content_block_)
- return { content = content_block_.thinking, signature = content_block_.signature }
+ return { thinking = content_block_.thinking, signature = content_block_.signature }
end
)
:totable()
+ local redacted_thinking_blocks = vim
+ .iter(ctx.content_blocks)
+ :filter(
+ function(content_block_) return content_block_.stoppped and content_block_.type == "redacted_thinking" end
+ )
+ :map(function(content_block_) return { data = content_block_.data } end)
+ :totable()
return {
name = content_block.name,
id = content_block.id,
input_json = content_block.input_json,
response_contents = response_contents,
- thinking_contents = thinking_contents,
+ thinking_blocks = thinking_blocks,
+ redacted_thinking_blocks = redacted_thinking_blocks,
}
end)
:totable()
opts.on_stop({
reason = "tool_use",
usage = jsn.usage,
diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua
index a891e4e..5b679e5 100644
--- a/lua/avante/providers/gemini.lua
+++ b/lua/avante/providers/gemini.lua
@@ -33,9 +33,15 @@ function M.parse_messages(opts)
end
end
prev_role = role
- table.insert(contents, { role = M.role_map[role] or role, parts = {
- { text = message.content },
- } })
+ local parts = {}
+ for _, item in ipairs(type(message.content) == "string" and { message.content } or message.content) do
+ if type(item) == "string" then
+ table.insert(parts, { text = item })
+ elseif item.type == "text" then
+ table.insert(parts, { text = item.text })
+ end
+ end
+ table.insert(contents, { role = M.role_map[role] or role, parts = parts })
end)
if Clipboard.support_paste_image() and opts.image_paths then
diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua
index 9eff43a..c6f3fd9 100644
--- a/lua/avante/providers/openai.lua
+++ b/lua/avante/providers/openai.lua
@@ -71,18 +71,30 @@ function M.parse_messages(opts)
-- NOTE: Handle the case where the selected model is the `o1` model
-- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
if M.is_o_series_model(base.model) then
- table.insert(messages, { role = "user", content = opts.system_prompt })
+ table.insert(messages, { role = "user", content = { { type = "text", text = opts.system_prompt } } })
else
table.insert(messages, { role = "system", content = opts.system_prompt })
end
- vim
- .iter(opts.messages)
- :each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
+ vim.iter(opts.messages):each(function(msg)
+ if type(msg.content) == "string" then
+ table.insert(messages, { role = M.role_map[msg.role], content = msg.content })
+ else
+ local content = {}
+ for _, item in ipairs(msg.content) do
+ if type(item) == "string" then
+ table.insert(content, { type = "text", text = item })
+ elseif item.type == "text" then
+ table.insert(content, { type = "text", text = item.text })
+ end
+ end
+ table.insert(messages, { role = M.role_map[msg.role], content = content })
+ end
+ end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
- if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end
+ if type(message_content) == "string" then message_content = { { type = "text", text = message_content } } messages[#messages].content = message_content end
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image_url",
@@ -103,7 +115,7 @@ function M.parse_messages(opts)
if role == M.role_map["user"] then
table.insert(final_messages, { role = M.role_map["assistant"], content = "Ok, I understand." })
else
- table.insert(final_messages, { role = M.role_map["user"], content = "Ok" })
+ table.insert(final_messages, { role = M.role_map["user"], content = { { type = "text", text = "Ok" } } })
end
end
prev_role = role
diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua
index c373f62..5de6e0c 100644
--- a/lua/avante/sidebar.lua
+++ b/lua/avante/sidebar.lua
@@ -392,6 +392,7 @@ local function transform_result_content(selected_files, result_content, prev_fil
elseif line_content == "" then
is_thinking = true
last_think_tag_start_line = i
+ last_think_tag_end_line = 0
elseif line_content == "" then
is_thinking = false
last_think_tag_end_line = i
@@ -2356,6 +2357,7 @@ function Sidebar:create_input_container(opts)
end
end
+ ---@type AvanteLLMMessage[]
local history_messages = {}
for i = #chat_history, 1, -1 do
local entry = chat_history[i]
@@ -2368,11 +2370,9 @@ function Sidebar:create_input_container(opts)
then
break
end
- table.insert(
- history_messages,
- 1,
- { role = "assistant", content = Utils.trim_think_content(entry.original_response) }
- )
+ local assistant_content = {}
+ table.insert(assistant_content, { type = "text", text = Utils.trim_think_content(entry.original_response) })
+ table.insert(history_messages, 1, { role = "assistant", content = assistant_content })
local user_content = ""
if entry.selected_file ~= nil then
user_content = user_content .. "SELECTED FILE: " .. entry.selected_file.filepath .. "\n\n"
@@ -2386,9 +2386,18 @@ function Sidebar:create_input_container(opts)
.. "\n```\n\n"
end
user_content = user_content .. "USER PROMPT:\n\n" .. entry.request
- table.insert(history_messages, 1, { role = "user", content = user_content })
+ table.insert(history_messages, 1, {
+ role = "user",
+ content = {
+ {
+ type = "text",
+ text = user_content,
+ },
+ },
+ })
end
+ ---@type AvanteGeneratePromptsOptions
return {
ask = opts.ask or true,
project_context = vim.json.encode(project_context),
diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua
index 10451f3..fee4782 100644
--- a/lua/avante/suggestion.lua
+++ b/lua/avante/suggestion.lua
@@ -78,7 +78,10 @@ function Suggestion:suggest()
local history_messages = {
{
role = "user",
- content = [[
+ content = {
+ {
+ type = "text",
+ text = [[
a.py
L1: def fib
@@ -87,7 +90,9 @@ L3: if __name__ == "__main__":
L4: # just pass
L5: pass
- ]],
+]],
+ },
+ },
},
{
role = "assistant",
@@ -95,7 +100,12 @@ L5: pass
},
{
role = "user",
- content = '{"insertSpaces":true,"tabSize":4,"indentSize":4,"position":{"row":1,"col":7}}',
+ content = {
+ {
+ type = "text",
+ text = '{"insertSpaces":true,"tabSize":4,"indentSize":4,"position":{"row":1,"col":7}}',
+ },
+ },
},
{
role = "assistant",
@@ -126,7 +136,7 @@ L5: pass
},
]
]
- ]],
+ ]],
},
}
diff --git a/lua/avante/types.lua b/lua/avante/types.lua
index fc214b4..7ed0cd1 100644
--- a/lua/avante/types.lua
+++ b/lua/avante/types.lua
@@ -76,9 +76,13 @@ vim.g.avante_login = vim.g.avante_login
---@field on_chunk AvanteLLMChunkCallback
---@field on_stop AvanteLLMStopCallback
---
+---@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string }
+---
+---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string
+---
---@class AvanteLLMMessage
---@field role "user" | "assistant"
----@field content string
+---@field content AvanteLLMMessageContent
---
---@class AvanteLLMToolResult
---@field tool_name string
@@ -227,11 +231,19 @@ vim.g.avante_login = vim.g.avante_login
---@field id string
---@field input_json string
---@field response_contents? string[]
----@field thinking_contents? { content: string, signature: string }[]
+---@field thinking_blocks? AvanteLLMThinkingBlock[]
+---@field redacted_thinking_blocks? AvanteLLMRedactedThinkingBlock[]
---
---@class AvanteLLMStartCallbackOptions
---@field usage? AvanteLLMUsage
---
+---@class AvanteLLMThinkingBlock
+---@field thinking string
+---@field signature string
+---
+---@class AvanteLLMRedactedThinkingBlock
+---@field data string
+---
---@class AvanteLLMStopCallbackOptions
---@field reason "complete" | "tool_use" | "error" | "rate_limit"
---@field error? string | table
@@ -343,3 +355,15 @@ vim.g.avante_login = vim.g.avante_login
---@field description string
---@field type string
---@field optional? boolean
+
+---@class avante.ChatHistoryEntry
+---@field timestamp string
+---@field provider string
+---@field model string
+---@field request string
+---@field response string
+---@field original_response string
+---@field selected_file {filepath: string}?
+---@field selected_code {filetype: string, content: string}?
+---@field reset_memory boolean?
+---@field selected_filepaths string[] | nil
diff --git a/lua/avante/utils/tokens.lua b/lua/avante/utils/tokens.lua
index 4f1382d..1193f99 100644
--- a/lua/avante/utils/tokens.lua
+++ b/lua/avante/utils/tokens.lua
@@ -9,6 +9,21 @@ local cost_per_token = {
davinci = 0.000002,
}
+---@param content AvanteLLMMessageContent
+---@return integer
+function Tokens.calculate_message_content_tokens(content)
+ if type(content) == "string" then return Tokens.calculate_tokens(content) end
+ local tokens = 0
+ for _, item in ipairs(content) do
+ if type(item) == "string" then
+ tokens = tokens + Tokens.calculate_tokens(item)
+ elseif item.type == "text" then
+ tokens = tokens + Tokens.calculate_tokens(item.text)
+ end
+ end
+ return tokens
+end
+
--- Calculate the number of tokens in a given text.
---@param text string The text to calculate the number of tokens for.
---@return integer The number of tokens in the given text.