perf(anthropic): prompt-caching (#517)

Bring back prompt caching support for the Anthropic provider.

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Aaron Pham
2024-09-04 03:19:33 -04:00
committed by GitHub
parent c027ea269a
commit 2b89f0d529
12 changed files with 116 additions and 75 deletions

View File

@@ -24,13 +24,27 @@ M.parse_message = function(opts)
end
end
local user_prompt_obj = {
type = "text",
text = opts.user_prompt,
}
if Utils.tokens.calculate_tokens(opts.user_prompt) then user_prompt_obj.cache_control = { type = "ephemeral" } end
---@type {idx: integer, length: integer}[]
local user_prompts_with_length = {}
for idx, user_prompt in ipairs(opts.user_prompts) do
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
end
table.insert(message_content, user_prompt_obj)
table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end)
---@type table<integer, boolean>
local top_three = {}
for i = 1, math.min(3, #user_prompts_with_length) do
top_three[user_prompts_with_length[i].idx] = true
end
for idx, prompt_data in ipairs(opts.user_prompts) do
table.insert(message_content, {
type = "text",
text = prompt_data,
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
})
end
return {
{

View File

@@ -34,7 +34,7 @@ M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
M.parse_message = function(opts)
return {
preamble = opts.system_prompt,
message = opts.user_prompt,
message = table.concat(opts.user_prompts, "\n"),
}
end

View File

@@ -118,7 +118,7 @@ M.tokenizer_id = "gpt-4o"
M.parse_message = function(opts)
return {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = opts.user_prompt },
{ role = "user", content = table.concat(opts.user_prompts, "\n") },
}
end

View File

@@ -25,7 +25,7 @@ M.parse_message = function(opts)
end
-- insert a part into parts
table.insert(message_content, { text = opts.user_prompt })
table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") })
return {
systemInstruction = {

View File

@@ -16,7 +16,7 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---
---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string
---@field user_prompt string
---@field user_prompts string[]
---@field image_paths? string[]
---
---@class AvanteBaseMessage

View File

@@ -28,13 +28,12 @@ local M = {}
M.api_key_name = "OPENAI_API_KEY"
---@param opts AvantePromptOptions
M.get_user_message = function(opts) return opts.user_prompt end
M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
M.parse_message = function(opts)
---@type string | OpenAIMessage[]
local user_content
---@type OpenAIMessage[]
local user_content = {}
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
user_content = {}
for _, image_path in ipairs(opts.image_paths) do
table.insert(user_content, {
type = "image_url",
@@ -43,9 +42,12 @@ M.parse_message = function(opts)
},
})
end
table.insert(user_content, { type = "text", text = opts.user_prompt })
vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
else
user_content = opts.user_prompt
user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
table.insert(acc, { type = "text", text = prompt })
return acc
end)
end
return {