fix: bedrock claude does not support prompt caching (#1507)

Author: yetone
Date: 2025-03-06 18:31:56 +08:00
Committed by: GitHub
Parent: 6d39e06f57
Commit: 5aa55689ff
12 changed files with 85 additions and 62 deletions


@@ -13,6 +13,8 @@ M.role_map = {
   assistant = "assistant",
 }
 
+function M:is_disable_stream() return false end
+
 ---@param tool AvanteLLMTool
 ---@return AvanteOpenAITool
 function M.transform_tool(tool)
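
The recurring `.` to `:` change throughout this diff turns module-level functions into methods, so a provider table that inherits from this one (such as the Bedrock path this fix targets) can dispatch through `self` and override pieces like `role_map`. A minimal sketch of Lua's colon sugar, with illustrative table names:

local M = {}

function M.plain(x) return x end          -- plain function: called as M.plain(x)
function M:method(x) return self, x end   -- sugar for M.method(self, x)

-- With metatable inheritance, the caller's table becomes `self`, so a
-- lookup like self.role_map resolves to the override when one exists.
local Derived = setmetatable({ role_map = { user = "human" } }, { __index = M })
local who = Derived:method(1)             -- who == Derived, not M
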
@@ -63,14 +65,14 @@ end
 function M.is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end
 
-function M.parse_messages(opts)
+function M:parse_messages(opts)
   local messages = {}
   local provider = P[Config.provider]
   local base, _ = P.parse_config(provider)
 
   -- NOTE: Handle the case where the selected model is the `o1` model
   -- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
-  if M.is_o_series_model(base.model) then
+  if self.is_o_series_model(base.model) then
     table.insert(messages, { role = "user", content = opts.system_prompt })
   else
     table.insert(messages, { role = "system", content = opts.system_prompt })
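
`is_o_series_model` keys off a Lua pattern rather than a hard-coded model list. A standalone illustration of what `^o%d+` does and does not match:

local function is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end

print(is_o_series_model("o1"))      --> true: "o" plus digits at the start
print(is_o_series_model("o3-mini")) --> true: the "o3" prefix matches
print(is_o_series_model("gpt-4o"))  --> false: the "o" is not at the start
print(is_o_series_model(nil))       --> nil: the `and` guard short-circuits
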
@@ -100,20 +102,20 @@ function M.parse_messages(opts)
   vim.iter(messages):each(function(message)
     local role = message.role
     if role == prev_role then
-      if role == M.role_map["user"] then
-        table.insert(final_messages, { role = M.role_map["assistant"], content = "Ok, I understand." })
+      if role == self.role_map["user"] then
+        table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
       else
-        table.insert(final_messages, { role = M.role_map["user"], content = "Ok" })
+        table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
       end
     end
     prev_role = role
-    table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
+    table.insert(final_messages, { role = self.role_map[role] or role, content = message.content })
   end)
 
   if opts.tool_histories then
     for _, tool_history in ipairs(opts.tool_histories) do
       table.insert(final_messages, {
-        role = M.role_map["assistant"],
+        role = self.role_map["assistant"],
         tool_calls = {
           {
             id = tool_history.tool_use.id,
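
The "Ok" / "Ok, I understand." stubs exist because OpenAI-style chat endpoints reject two consecutive messages from the same role, so the loop splices in a filler turn from the opposite role whenever the history repeats a role. A self-contained sketch of that normalization (plain `ipairs` instead of `vim.iter`, so it runs outside Neovim):

local role_map = { user = "user", assistant = "assistant" }
local messages = {
  { role = "user", content = "first question" },
  { role = "user", content = "follow-up" }, -- same role twice in a row
}

local final_messages, prev_role = {}, nil
for _, message in ipairs(messages) do
  if message.role == prev_role then
    -- restore alternation with a stub turn from the opposite role
    local stub = message.role == "user"
        and { role = role_map["assistant"], content = "Ok, I understand." }
      or { role = role_map["user"], content = "Ok" }
    table.insert(final_messages, stub)
  end
  prev_role = message.role
  table.insert(final_messages, { role = role_map[message.role] or message.role, content = message.content })
end
-- final_messages: user, assistant ("Ok, I understand."), user
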
@@ -137,7 +139,7 @@ function M.parse_messages(opts)
   return final_messages
 end
 
-function M.parse_response(ctx, data_stream, _, opts)
+function M:parse_response(ctx, data_stream, _, opts)
   if data_stream:match('"%[DONE%]":') then
     opts.on_stop({ reason = "complete" })
     return
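
In `data_stream:match('"%[DONE%]":')`, the `%[` and `%]` escapes are needed because square brackets are magic characters in Lua patterns; the check looks for a literal `"[DONE]":` in the chunk, the stream-termination sentinel. A quick demonstration with made-up chunk contents:

local done_chunk = '{"[DONE]": true}'            -- hypothetical terminal chunk
print(done_chunk:match('"%[DONE%]":') ~= nil)    --> true
print(('{"delta": "hi"}'):match('"%[DONE%]":'))  --> nil for ordinary chunks
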
@@ -205,7 +207,7 @@ function M.parse_response(ctx, data_stream, _, opts)
   end
 end
 
-function M.parse_response_without_stream(data, _, opts)
+function M:parse_response_without_stream(data, _, opts)
   ---@type AvanteOpenAIChatResponse
   local json = vim.json.decode(data)
   if json.choices and json.choices[1] then
@@ -217,7 +219,7 @@ function M.parse_response_without_stream(data, _, opts)
   end
 end
 
-function M.parse_curl_args(provider, prompt_opts)
+function M:parse_curl_args(provider, prompt_opts)
   local provider_conf, request_body = P.parse_config(provider)
   local disable_tools = provider_conf.disable_tools or false
@@ -240,8 +242,7 @@ function M.parse_curl_args(provider, prompt_opts)
   end
 
   -- NOTE: When using "o" series set the supported parameters only
-  local stream = true
-  if M.is_o_series_model(provider_conf.model) then
+  if self.is_o_series_model(provider_conf.model) then
     request_body.max_completion_tokens = request_body.max_tokens
     request_body.max_tokens = nil
     request_body.temperature = 1
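
For "o" series models the request body is rewritten in place: the token budget moves from `max_tokens` to `max_completion_tokens`, and `temperature` is pinned to 1, which those endpoints expect. A before/after sketch with an illustrative value:

local request_body = { max_tokens = 4096, temperature = 0 }

request_body.max_completion_tokens = request_body.max_tokens -- 4096 moves over
request_body.max_tokens = nil                                -- old key dropped
request_body.temperature = 1                                 -- the only accepted value

-- request_body == { max_completion_tokens = 4096, temperature = 1 }
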
@@ -251,7 +252,7 @@ function M.parse_curl_args(provider, prompt_opts)
   if not disable_tools and prompt_opts.tools then
     tools = {}
     for _, tool in ipairs(prompt_opts.tools) do
-      table.insert(tools, M.transform_tool(tool))
+      table.insert(tools, self.transform_tool(tool))
     end
   end
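
`transform_tool` keeps its dot call (`self.transform_tool`) since it takes no implicit self. Its body is not in this diff; per the annotations it maps an `AvanteLLMTool` to an `AvanteOpenAITool`, presumably OpenAI's function-calling shape. A hypothetical sketch, not the repository's implementation:

local function transform_tool(tool)
  return {
    type = "function",
    ["function"] = {                -- `function` is a Lua keyword, hence bracket syntax
      name = tool.name,
      description = tool.description,
      parameters = tool.parameters, -- assumed: a JSON-schema-like table on the tool
    },
  }
end
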
@@ -265,8 +266,8 @@ function M.parse_curl_args(provider, prompt_opts)
     headers = headers,
     body = vim.tbl_deep_extend("force", {
       model = provider_conf.model,
-      messages = M.parse_messages(prompt_opts),
-      stream = stream,
+      messages = self:parse_messages(prompt_opts),
+      stream = true,
       tools = tools,
     }, request_body),
   }
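
Because `vim.tbl_deep_extend("force", ...)` gives later tables precedence and merges nested tables recursively, `request_body` (derived from user config) can still override the hard-coded `stream = true`, which is presumably why the `local stream` variable could be dropped. The merge semantics in brief (runnable inside Neovim):

local merged = vim.tbl_deep_extend("force",
  { model = "default-model", stream = true, opts = { temperature = 0, top_p = 1 } },
  { stream = false, opts = { temperature = 1 } } -- later table wins on conflicts
)
-- merged == { model = "default-model", stream = false,
--             opts = { temperature = 1, top_p = 1 } }
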