diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index f0498fe..fa3d448 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -50,6 +50,7 @@ end function M:is_disable_stream() return false end +---@return AvanteClaudeMessage[] function M:parse_messages(opts) ---@type AvanteClaudeMessage[] local messages = {} @@ -64,16 +65,8 @@ function M:parse_messages(opts) table.sort(messages_with_length, function(a, b) return a.length > b.length end) - ---@type table - local top_two = {} - if self.support_prompt_caching then - for i = 1, math.min(2, #messages_with_length) do - top_two[messages_with_length[i].idx] = true - end - end - local has_tool_use = false - for idx, message in ipairs(opts.messages) do + for _, message in ipairs(opts.messages) do local content_items = message.content local message_content = {} if type(content_items) == "string" then @@ -82,7 +75,6 @@ function M:parse_messages(opts) table.insert(message_content, { type = "text", text = content_items, - cache_control = top_two[idx] and { type = "ephemeral" } or nil, }) end elseif type(content_items) == "table" then @@ -90,15 +82,9 @@ function M:parse_messages(opts) for _, item in ipairs(content_items) do if type(item) == "string" then if message.role == "assistant" then item = item:gsub("%s+$", "") end - table.insert( - message_content, - { type = "text", text = item, cache_control = top_two[idx] and { type = "ephemeral" } or nil } - ) + table.insert(message_content, { type = "text", text = item }) elseif type(item) == "table" and item.type == "text" then - table.insert( - message_content, - { type = "text", text = item.text, cache_control = top_two[idx] and { type = "ephemeral" } or nil } - ) + table.insert(message_content, { type = "text", text = item.text }) elseif type(item) == "table" and item.type == "image" then table.insert(message_content, { type = "image", source = item.source }) elseif not provider_conf.disable_tools and type(item) == "table" and item.type == "tool_use" then @@ -386,10 +372,34 @@ function M:parse_curl_args(prompt_opts) end end - if self.support_prompt_caching and #tools > 0 then - local last_tool = vim.deepcopy(tools[#tools]) - last_tool.cache_control = { type = "ephemeral" } - tools[#tools] = last_tool + if self.support_prompt_caching then + if #messages > 0 then + local found = false + for i = #messages, 1, -1 do + local message = messages[i] + message = vim.deepcopy(message) + ---@cast message AvanteClaudeMessage + local content = message.content + ---@cast content AvanteClaudeMessageContentTextItem[] + for j = #content, 1, -1 do + local item = content[j] + if item.type == "text" then + item.cache_control = { type = "ephemeral" } + found = true + break + end + end + if found then + messages[i] = message + break + end + end + end + if #tools > 0 then + local last_tool = vim.deepcopy(tools[#tools]) + last_tool.cache_control = { type = "ephemeral" } + tools[#tools] = last_tool + end end return { @@ -403,7 +413,7 @@ function M:parse_curl_args(prompt_opts) { type = "text", text = prompt_opts.system_prompt, - cache_control = { type = "ephemeral" }, + cache_control = self.support_prompt_caching and { type = "ephemeral" } or nil, }, }, messages = messages, diff --git a/lua/avante/types.lua b/lua/avante/types.lua index dd82757..7f975be 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -78,7 +78,7 @@ vim.g.avante_login = vim.g.avante_login ---@field on_messages_add? fun(messages: avante.HistoryMessage[]): nil ---@field on_state_change? fun(state: avante.GenerateState): nil --- ----@alias AvanteLLMMessageContentItem string | { type: "text", text: string } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } +---@alias AvanteLLMMessageContentItem string | { type: "text", text: string, cache_control: { type: string } | nil } | { type: "image", source: { type: "base64", media_type: string, data: string } } | { type: "tool_use", name: string, id: string, input: any } | { type: "tool_result", tool_use_id: string, content: string, is_error?: boolean } | { type: "thinking", thinking: string, signature: string } | { type: "redacted_thinking", data: string } ---@alias AvanteLLMMessageContent AvanteLLMMessageContentItem[] | string @@ -129,20 +129,20 @@ vim.g.avante_login = vim.g.avante_login ---@field role "user" ---@field parts { text: string }[] --- ----@class AvanteClaudeBaseMessage +---@class AvanteClaudeMessageContentBaseItem ---@field cache_control {type: "ephemeral"}? --- ----@class AvanteClaudeTextMessage: AvanteClaudeBaseMessage +---@class AvanteClaudeMessageContentTextItem: AvanteClaudeMessageContentBaseItem ---@field type "text" ---@field text string --- ----@class AvanteClaudeImageMessage: AvanteClaudeBaseMessage +---@class AvanteClaudeMessageCotnentImageItem: AvanteClaudeMessageContentBaseItem ---@field type "image" ---@field source {type: "base64", media_type: string, data: string} --- ---@class AvanteClaudeMessage ---@field role "user" | "assistant" ----@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][] +---@field content [AvanteClaudeMessageContentTextItem | AvanteClaudeMessageCotnentImageItem][] ---@class AvanteClaudeTool ---@field name string