fix: Bedrock Claude does not support prompt caching (#1507)
This commit is contained in:
@@ -13,6 +13,8 @@ M.role_map = {
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
---Report whether streaming is disabled for this provider.
---@return boolean # always `false` in this implementation
function M:is_disable_stream()
  return false
end
|
||||
|
||||
---@param tool AvanteLLMTool
|
||||
---@return AvanteOpenAITool
|
||||
function M.transform_tool(tool)
|
||||
@@ -63,14 +65,14 @@ end
|
||||
|
||||
---Check whether a model name starts with `o` followed by at least one digit
---(for example `o1` or `o3-mini`).
---@param model string|nil
---@return boolean|nil # a falsy `model` is returned unchanged; otherwise a boolean
function M.is_o_series_model(model)
  -- Pass a missing model straight through so callers see the same falsy value.
  if not model then return model end
  local first = string.find(model, "^o%d+")
  return first ~= nil
end
|
||||
|
||||
function M.parse_messages(opts)
|
||||
function M:parse_messages(opts)
|
||||
local messages = {}
|
||||
local provider = P[Config.provider]
|
||||
local base, _ = P.parse_config(provider)
|
||||
|
||||
-- NOTE: Handle the case where the selected model is the `o1` model
|
||||
-- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
|
||||
if M.is_o_series_model(base.model) then
|
||||
if self.is_o_series_model(base.model) then
|
||||
table.insert(messages, { role = "user", content = opts.system_prompt })
|
||||
else
|
||||
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||
@@ -100,20 +102,20 @@ function M.parse_messages(opts)
|
||||
vim.iter(messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role then
|
||||
if role == M.role_map["user"] then
|
||||
table.insert(final_messages, { role = M.role_map["assistant"], content = "Ok, I understand." })
|
||||
if role == self.role_map["user"] then
|
||||
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
|
||||
else
|
||||
table.insert(final_messages, { role = M.role_map["user"], content = "Ok" })
|
||||
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
||||
end
|
||||
end
|
||||
prev_role = role
|
||||
table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
|
||||
table.insert(final_messages, { role = self.role_map[role] or role, content = message.content })
|
||||
end)
|
||||
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
table.insert(final_messages, {
|
||||
role = M.role_map["assistant"],
|
||||
role = self.role_map["assistant"],
|
||||
tool_calls = {
|
||||
{
|
||||
id = tool_history.tool_use.id,
|
||||
@@ -137,7 +139,7 @@ function M.parse_messages(opts)
|
||||
return final_messages
|
||||
end
|
||||
|
||||
function M.parse_response(ctx, data_stream, _, opts)
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
@@ -205,7 +207,7 @@ function M.parse_response(ctx, data_stream, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
function M.parse_response_without_stream(data, _, opts)
|
||||
function M:parse_response_without_stream(data, _, opts)
|
||||
---@type AvanteOpenAIChatResponse
|
||||
local json = vim.json.decode(data)
|
||||
if json.choices and json.choices[1] then
|
||||
@@ -217,7 +219,7 @@ function M.parse_response_without_stream(data, _, opts)
|
||||
end
|
||||
end
|
||||
|
||||
function M.parse_curl_args(provider, prompt_opts)
|
||||
function M:parse_curl_args(provider, prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(provider)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
@@ -240,8 +242,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
||||
end
|
||||
|
||||
-- NOTE: When using "o" series set the supported parameters only
|
||||
local stream = true
|
||||
if M.is_o_series_model(provider_conf.model) then
|
||||
if self.is_o_series_model(provider_conf.model) then
|
||||
request_body.max_completion_tokens = request_body.max_tokens
|
||||
request_body.max_tokens = nil
|
||||
request_body.temperature = 1
|
||||
@@ -251,7 +252,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
tools = {}
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(tools, M.transform_tool(tool))
|
||||
table.insert(tools, self.transform_tool(tool))
|
||||
end
|
||||
end
|
||||
|
||||
@@ -265,8 +266,8 @@ function M.parse_curl_args(provider, prompt_opts)
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = provider_conf.model,
|
||||
messages = M.parse_messages(prompt_opts),
|
||||
stream = stream,
|
||||
messages = self:parse_messages(prompt_opts),
|
||||
stream = true,
|
||||
tools = tools,
|
||||
}, request_body),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user