fix: bedrock claude do not support prompt caching (#1507)
This commit is contained in:
@@ -544,6 +544,7 @@ M.BASE_PROVIDER_KEYS = {
|
|||||||
"tokenizer_id",
|
"tokenizer_id",
|
||||||
"use_xml_format",
|
"use_xml_format",
|
||||||
"role_map",
|
"role_map",
|
||||||
|
"support_prompt_caching",
|
||||||
"__inherited_from",
|
"__inherited_from",
|
||||||
"disable_tools",
|
"disable_tools",
|
||||||
"entra",
|
"entra",
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ function M._stream(opts)
|
|||||||
}
|
}
|
||||||
|
|
||||||
---@type AvanteCurlOutput
|
---@type AvanteCurlOutput
|
||||||
local spec = provider.parse_curl_args(provider, prompt_opts)
|
local spec = provider:parse_curl_args(provider, prompt_opts)
|
||||||
|
|
||||||
local resp_ctx = {}
|
local resp_ctx = {}
|
||||||
|
|
||||||
@@ -244,11 +244,11 @@ function M._stream(opts)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
local data_match = line:match("^data: (.+)$")
|
local data_match = line:match("^data: (.+)$")
|
||||||
if data_match then provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
|
if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end
|
||||||
end
|
end
|
||||||
|
|
||||||
local function parse_response_without_stream(data)
|
local function parse_response_without_stream(data)
|
||||||
provider.parse_response_without_stream(data, current_event_state, handler_opts)
|
provider:parse_response_without_stream(data, current_event_state, handler_opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
local completed = false
|
local completed = false
|
||||||
@@ -287,10 +287,10 @@ function M._stream(opts)
|
|||||||
{ once = true }
|
{ once = true }
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
provider.parse_stream_data(resp_ctx, data, handler_opts)
|
provider:parse_stream_data(resp_ctx, data, handler_opts)
|
||||||
else
|
else
|
||||||
if provider.parse_stream_data ~= nil then
|
if provider.parse_stream_data ~= nil then
|
||||||
provider.parse_stream_data(resp_ctx, data, handler_opts)
|
provider:parse_stream_data(resp_ctx, data, handler_opts)
|
||||||
else
|
else
|
||||||
parse_stream_data(data)
|
parse_stream_data(data)
|
||||||
end
|
end
|
||||||
@@ -357,7 +357,7 @@ function M._stream(opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
-- If stream is not enabled, then handle the response here
|
-- If stream is not enabled, then handle the response here
|
||||||
if (spec.body.stream == nil or spec.body.stream == false) and result.status == 200 then
|
if provider:is_disable_stream() and result.status == 200 then
|
||||||
vim.schedule(function()
|
vim.schedule(function()
|
||||||
completed = true
|
completed = true
|
||||||
parse_response_without_stream(result.body)
|
parse_response_without_stream(result.body)
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ M.api_key_name = "AZURE_OPENAI_API_KEY"
|
|||||||
M.parse_messages = O.parse_messages
|
M.parse_messages = O.parse_messages
|
||||||
M.parse_response = O.parse_response
|
M.parse_response = O.parse_response
|
||||||
M.parse_response_without_stream = O.parse_response_without_stream
|
M.parse_response_without_stream = O.parse_response_without_stream
|
||||||
|
M.is_disable_stream = O.is_disable_stream
|
||||||
|
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
@@ -52,7 +53,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
messages = M.parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
}, request_body),
|
}, request_body),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,17 +18,17 @@ function M.load_model_handler()
|
|||||||
error(error_msg)
|
error(error_msg)
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_response(ctx, data_stream, event_state, opts)
|
function M:parse_response(ctx, data_stream, event_state, opts)
|
||||||
local model_handler = M.load_model_handler()
|
local model_handler = M.load_model_handler()
|
||||||
return model_handler.parse_response(ctx, data_stream, event_state, opts)
|
return model_handler.parse_response(self, ctx, data_stream, event_state, opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.build_bedrock_payload(prompt_opts, body_opts)
|
function M:build_bedrock_payload(prompt_opts, body_opts)
|
||||||
local model_handler = M.load_model_handler()
|
local model_handler = M.load_model_handler()
|
||||||
return model_handler.build_bedrock_payload(prompt_opts, body_opts)
|
return model_handler.build_bedrock_payload(self, prompt_opts, body_opts)
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_stream_data(ctx, data, opts)
|
function M:parse_stream_data(ctx, data, opts)
|
||||||
-- @NOTE: Decode and process Bedrock response
|
-- @NOTE: Decode and process Bedrock response
|
||||||
-- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
|
-- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON.
|
||||||
-- The `type` field in the decoded JSON determines how the response is handled.
|
-- The `type` field in the decoded JSON determines how the response is handled.
|
||||||
@@ -37,11 +37,11 @@ function M.parse_stream_data(ctx, data, opts)
|
|||||||
local jsn = vim.json.decode(bedrock_data_match)
|
local jsn = vim.json.decode(bedrock_data_match)
|
||||||
local data_stream = vim.base64.decode(jsn.bytes)
|
local data_stream = vim.base64.decode(jsn.bytes)
|
||||||
local json = vim.json.decode(data_stream)
|
local json = vim.json.decode(data_stream)
|
||||||
M.parse_response(ctx, data_stream, json.type, opts)
|
self:parse_response(ctx, data_stream, json.type, opts)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_response_without_stream(data, event_state, opts)
|
function M:parse_response_without_stream(data, event_state, opts)
|
||||||
local bedrock_match = data:gmatch("exception(%b{})")
|
local bedrock_match = data:gmatch("exception(%b{})")
|
||||||
opts.on_chunk("\n**Exception caught**\n\n")
|
opts.on_chunk("\n**Exception caught**\n\n")
|
||||||
for bedrock_data_match in bedrock_match do
|
for bedrock_data_match in bedrock_match do
|
||||||
@@ -54,7 +54,7 @@ end
|
|||||||
---@param provider AvanteBedrockProviderFunctor
|
---@param provider AvanteBedrockProviderFunctor
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return table
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local base, body_opts = P.parse_config(provider)
|
local base, body_opts = P.parse_config(provider)
|
||||||
|
|
||||||
local api_key = provider.parse_api_key()
|
local api_key = provider.parse_api_key()
|
||||||
@@ -77,7 +77,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
|
|
||||||
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
|
if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end
|
||||||
|
|
||||||
local body_payload = M.build_bedrock_payload(prompt_opts, body_opts)
|
local body_payload = self:build_bedrock_payload(prompt_opts, body_opts)
|
||||||
|
|
||||||
local rawArgs = {
|
local rawArgs = {
|
||||||
"--aws-sigv4",
|
"--aws-sigv4",
|
||||||
|
|||||||
@@ -11,20 +11,23 @@ local Claude = require("avante.providers.claude")
|
|||||||
---@class AvanteBedrockModelHandler
|
---@class AvanteBedrockModelHandler
|
||||||
local M = {}
|
local M = {}
|
||||||
|
|
||||||
|
M.support_prompt_caching = false
|
||||||
M.role_map = {
|
M.role_map = {
|
||||||
user = "user",
|
user = "user",
|
||||||
assistant = "assistant",
|
assistant = "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
M.is_disable_stream = Claude.is_disable_stream
|
||||||
M.parse_messages = Claude.parse_messages
|
M.parse_messages = Claude.parse_messages
|
||||||
M.parse_response = Claude.parse_response
|
M.parse_response = Claude.parse_response
|
||||||
|
|
||||||
|
---@param provider AvanteProviderFunctor
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@param body_opts table
|
---@param body_opts table
|
||||||
---@return table
|
---@return table
|
||||||
function M.build_bedrock_payload(prompt_opts, body_opts)
|
function M.build_bedrock_payload(provider, prompt_opts, body_opts)
|
||||||
local system_prompt = prompt_opts.system_prompt or ""
|
local system_prompt = prompt_opts.system_prompt or ""
|
||||||
local messages = M.parse_messages(prompt_opts)
|
local messages = provider:parse_messages(prompt_opts)
|
||||||
local max_tokens = body_opts.max_tokens or 2000
|
local max_tokens = body_opts.max_tokens or 2000
|
||||||
local payload = {
|
local payload = {
|
||||||
anthropic_version = "bedrock-2023-05-31",
|
anthropic_version = "bedrock-2023-05-31",
|
||||||
|
|||||||
@@ -30,13 +30,16 @@ local M = {}
|
|||||||
|
|
||||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||||
M.use_xml_format = true
|
M.use_xml_format = true
|
||||||
|
M.support_prompt_caching = true
|
||||||
|
|
||||||
M.role_map = {
|
M.role_map = {
|
||||||
user = "user",
|
user = "user",
|
||||||
assistant = "assistant",
|
assistant = "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
function M.parse_messages(opts)
|
function M:is_disable_stream() return false end
|
||||||
|
|
||||||
|
function M:parse_messages(opts)
|
||||||
---@type AvanteClaudeMessage[]
|
---@type AvanteClaudeMessage[]
|
||||||
local messages = {}
|
local messages = {}
|
||||||
|
|
||||||
@@ -50,13 +53,15 @@ function M.parse_messages(opts)
|
|||||||
|
|
||||||
---@type table<integer, boolean>
|
---@type table<integer, boolean>
|
||||||
local top_two = {}
|
local top_two = {}
|
||||||
|
if self.support_prompt_caching then
|
||||||
for i = 1, math.min(2, #messages_with_length) do
|
for i = 1, math.min(2, #messages_with_length) do
|
||||||
top_two[messages_with_length[i].idx] = true
|
top_two[messages_with_length[i].idx] = true
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
for idx, message in ipairs(opts.messages) do
|
for idx, message in ipairs(opts.messages) do
|
||||||
table.insert(messages, {
|
table.insert(messages, {
|
||||||
role = M.role_map[message.role],
|
role = self.role_map[message.role],
|
||||||
content = {
|
content = {
|
||||||
{
|
{
|
||||||
type = "text",
|
type = "text",
|
||||||
@@ -142,7 +147,7 @@ function M.parse_messages(opts)
|
|||||||
return messages
|
return messages
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_response(ctx, data_stream, event_state, opts)
|
function M:parse_response(ctx, data_stream, event_state, opts)
|
||||||
if event_state == nil then
|
if event_state == nil then
|
||||||
if data_stream:match('"message_start"') then
|
if data_stream:match('"message_start"') then
|
||||||
event_state = "message_start"
|
event_state = "message_start"
|
||||||
@@ -260,7 +265,7 @@ end
|
|||||||
---@param provider AvanteProviderFunctor
|
---@param provider AvanteProviderFunctor
|
||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
---@return table
|
---@return table
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
@@ -272,7 +277,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
|
|
||||||
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
|
if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end
|
||||||
|
|
||||||
local messages = M.parse_messages(prompt_opts)
|
local messages = self:parse_messages(prompt_opts)
|
||||||
|
|
||||||
local tools = {}
|
local tools = {}
|
||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
@@ -281,7 +286,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
if #tools > 0 then
|
if self.support_prompt_caching and #tools > 0 then
|
||||||
local last_tool = vim.deepcopy(tools[#tools])
|
local last_tool = vim.deepcopy(tools[#tools])
|
||||||
last_tool.cache_control = { type = "ephemeral" }
|
last_tool.cache_control = { type = "ephemeral" }
|
||||||
tools[#tools] = last_tool
|
tools[#tools] = last_tool
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ M.role_map = {
|
|||||||
assistant = "assistant",
|
assistant = "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
function M.parse_messages(opts)
|
function M:is_disable_stream() return false end
|
||||||
|
|
||||||
|
function M:parse_messages(opts)
|
||||||
local messages = {
|
local messages = {
|
||||||
{ role = "system", content = opts.system_prompt },
|
{ role = "system", content = opts.system_prompt },
|
||||||
}
|
}
|
||||||
@@ -57,7 +59,7 @@ function M.parse_messages(opts)
|
|||||||
return { messages = messages }
|
return { messages = messages }
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_stream_data(ctx, data, opts)
|
function M:parse_stream_data(ctx, data, opts)
|
||||||
---@type CohereChatResponse
|
---@type CohereChatResponse
|
||||||
local json = vim.json.decode(data)
|
local json = vim.json.decode(data)
|
||||||
if json.type ~= nil then
|
if json.type ~= nil then
|
||||||
@@ -69,7 +71,7 @@ function M.parse_stream_data(ctx, data, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
local headers = {
|
local headers = {
|
||||||
@@ -92,7 +94,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
stream = true,
|
stream = true,
|
||||||
}, M.parse_messages(prompt_opts), request_body),
|
}, self:parse_messages(prompt_opts), request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -209,11 +209,13 @@ M.role_map = {
|
|||||||
assistant = "assistant",
|
assistant = "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function M:is_disable_stream() return false end
|
||||||
|
|
||||||
M.parse_messages = OpenAI.parse_messages
|
M.parse_messages = OpenAI.parse_messages
|
||||||
|
|
||||||
M.parse_response = OpenAI.parse_response
|
M.parse_response = OpenAI.parse_response
|
||||||
|
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
-- refresh token synchronously, only if it has expired
|
-- refresh token synchronously, only if it has expired
|
||||||
-- (this should rarely happen, as we refresh the token in the background)
|
-- (this should rarely happen, as we refresh the token in the background)
|
||||||
H.refresh_token(false, false)
|
H.refresh_token(false, false)
|
||||||
@@ -241,7 +243,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
},
|
},
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
messages = M.parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
stream = true,
|
stream = true,
|
||||||
tools = tools,
|
tools = tools,
|
||||||
}, request_body),
|
}, request_body),
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ M.role_map = {
|
|||||||
}
|
}
|
||||||
-- M.tokenizer_id = "google/gemma-2b"
|
-- M.tokenizer_id = "google/gemma-2b"
|
||||||
|
|
||||||
function M.parse_messages(opts)
|
function M:is_disable_stream() return false end
|
||||||
|
|
||||||
|
function M:parse_messages(opts)
|
||||||
local contents = {}
|
local contents = {}
|
||||||
local prev_role = nil
|
local prev_role = nil
|
||||||
|
|
||||||
@@ -64,7 +66,7 @@ function M.parse_messages(opts)
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_response(ctx, data_stream, _, opts)
|
function M:parse_response(ctx, data_stream, _, opts)
|
||||||
local ok, json = pcall(vim.json.decode, data_stream)
|
local ok, json = pcall(vim.json.decode, data_stream)
|
||||||
if not ok then opts.on_stop({ reason = "error", error = json }) end
|
if not ok then opts.on_stop({ reason = "error", error = json }) end
|
||||||
if json.candidates then
|
if json.candidates then
|
||||||
@@ -81,7 +83,7 @@ function M.parse_response(ctx, data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
|
|
||||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||||
@@ -104,7 +106,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
headers = { ["Content-Type"] = "application/json" },
|
headers = { ["Content-Type"] = "application/json" },
|
||||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
|
body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ M.role_map = {
|
|||||||
assistant = "assistant",
|
assistant = "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function M:is_disable_stream() return false end
|
||||||
|
|
||||||
---@param tool AvanteLLMTool
|
---@param tool AvanteLLMTool
|
||||||
---@return AvanteOpenAITool
|
---@return AvanteOpenAITool
|
||||||
function M.transform_tool(tool)
|
function M.transform_tool(tool)
|
||||||
@@ -63,14 +65,14 @@ end
|
|||||||
|
|
||||||
function M.is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end
|
function M.is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end
|
||||||
|
|
||||||
function M.parse_messages(opts)
|
function M:parse_messages(opts)
|
||||||
local messages = {}
|
local messages = {}
|
||||||
local provider = P[Config.provider]
|
local provider = P[Config.provider]
|
||||||
local base, _ = P.parse_config(provider)
|
local base, _ = P.parse_config(provider)
|
||||||
|
|
||||||
-- NOTE: Handle the case where the selected model is the `o1` model
|
-- NOTE: Handle the case where the selected model is the `o1` model
|
||||||
-- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
|
-- "o1" models are "smart" enough to understand user prompt as a system prompt in this context
|
||||||
if M.is_o_series_model(base.model) then
|
if self.is_o_series_model(base.model) then
|
||||||
table.insert(messages, { role = "user", content = opts.system_prompt })
|
table.insert(messages, { role = "user", content = opts.system_prompt })
|
||||||
else
|
else
|
||||||
table.insert(messages, { role = "system", content = opts.system_prompt })
|
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||||
@@ -100,20 +102,20 @@ function M.parse_messages(opts)
|
|||||||
vim.iter(messages):each(function(message)
|
vim.iter(messages):each(function(message)
|
||||||
local role = message.role
|
local role = message.role
|
||||||
if role == prev_role then
|
if role == prev_role then
|
||||||
if role == M.role_map["user"] then
|
if role == self.role_map["user"] then
|
||||||
table.insert(final_messages, { role = M.role_map["assistant"], content = "Ok, I understand." })
|
table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." })
|
||||||
else
|
else
|
||||||
table.insert(final_messages, { role = M.role_map["user"], content = "Ok" })
|
table.insert(final_messages, { role = self.role_map["user"], content = "Ok" })
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
prev_role = role
|
prev_role = role
|
||||||
table.insert(final_messages, { role = M.role_map[role] or role, content = message.content })
|
table.insert(final_messages, { role = self.role_map[role] or role, content = message.content })
|
||||||
end)
|
end)
|
||||||
|
|
||||||
if opts.tool_histories then
|
if opts.tool_histories then
|
||||||
for _, tool_history in ipairs(opts.tool_histories) do
|
for _, tool_history in ipairs(opts.tool_histories) do
|
||||||
table.insert(final_messages, {
|
table.insert(final_messages, {
|
||||||
role = M.role_map["assistant"],
|
role = self.role_map["assistant"],
|
||||||
tool_calls = {
|
tool_calls = {
|
||||||
{
|
{
|
||||||
id = tool_history.tool_use.id,
|
id = tool_history.tool_use.id,
|
||||||
@@ -137,7 +139,7 @@ function M.parse_messages(opts)
|
|||||||
return final_messages
|
return final_messages
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_response(ctx, data_stream, _, opts)
|
function M:parse_response(ctx, data_stream, _, opts)
|
||||||
if data_stream:match('"%[DONE%]":') then
|
if data_stream:match('"%[DONE%]":') then
|
||||||
opts.on_stop({ reason = "complete" })
|
opts.on_stop({ reason = "complete" })
|
||||||
return
|
return
|
||||||
@@ -205,7 +207,7 @@ function M.parse_response(ctx, data_stream, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_response_without_stream(data, _, opts)
|
function M:parse_response_without_stream(data, _, opts)
|
||||||
---@type AvanteOpenAIChatResponse
|
---@type AvanteOpenAIChatResponse
|
||||||
local json = vim.json.decode(data)
|
local json = vim.json.decode(data)
|
||||||
if json.choices and json.choices[1] then
|
if json.choices and json.choices[1] then
|
||||||
@@ -217,7 +219,7 @@ function M.parse_response_without_stream(data, _, opts)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
local disable_tools = provider_conf.disable_tools or false
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
|
|
||||||
@@ -240,8 +242,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
end
|
end
|
||||||
|
|
||||||
-- NOTE: When using "o" series set the supported parameters only
|
-- NOTE: When using "o" series set the supported parameters only
|
||||||
local stream = true
|
if self.is_o_series_model(provider_conf.model) then
|
||||||
if M.is_o_series_model(provider_conf.model) then
|
|
||||||
request_body.max_completion_tokens = request_body.max_tokens
|
request_body.max_completion_tokens = request_body.max_tokens
|
||||||
request_body.max_tokens = nil
|
request_body.max_tokens = nil
|
||||||
request_body.temperature = 1
|
request_body.temperature = 1
|
||||||
@@ -251,7 +252,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
tools = {}
|
tools = {}
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(tools, M.transform_tool(tool))
|
table.insert(tools, self.transform_tool(tool))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -265,8 +266,8 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
headers = headers,
|
headers = headers,
|
||||||
body = vim.tbl_deep_extend("force", {
|
body = vim.tbl_deep_extend("force", {
|
||||||
model = provider_conf.model,
|
model = provider_conf.model,
|
||||||
messages = M.parse_messages(prompt_opts),
|
messages = self:parse_messages(prompt_opts),
|
||||||
stream = stream,
|
stream = true,
|
||||||
tools = tools,
|
tools = tools,
|
||||||
}, request_body),
|
}, request_body),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ M.role_map = {
|
|||||||
assistant = "model",
|
assistant = "model",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
M.is_disable_stream = Gemini.is_disable_stream
|
||||||
M.parse_messages = Gemini.parse_messages
|
M.parse_messages = Gemini.parse_messages
|
||||||
M.parse_response = Gemini.parse_response
|
M.parse_response = Gemini.parse_response
|
||||||
|
|
||||||
@@ -31,11 +32,13 @@ function M.parse_api_key()
|
|||||||
return direct_output
|
return direct_output
|
||||||
end
|
end
|
||||||
|
|
||||||
function M.parse_curl_args(provider, prompt_opts)
|
function M:parse_curl_args(provider, prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(provider)
|
local provider_conf, request_body = P.parse_config(provider)
|
||||||
local location = vim.fn.getenv("LOCATION") or "default-location"
|
local location = vim.fn.getenv("LOCATION")
|
||||||
local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id"
|
local project_id = vim.fn.getenv("PROJECT_ID")
|
||||||
local model_id = provider_conf.model or "default-model-id"
|
local model_id = provider_conf.model or "default-model-id"
|
||||||
|
if location == nil or location == vim.NIL then location = "default-location" end
|
||||||
|
if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end
|
||||||
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id)
|
||||||
|
|
||||||
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id)
|
||||||
@@ -58,7 +61,7 @@ function M.parse_curl_args(provider, prompt_opts)
|
|||||||
},
|
},
|
||||||
proxy = provider_conf.proxy,
|
proxy = provider_conf.proxy,
|
||||||
insecure = provider_conf.allow_insecure,
|
insecure = provider_conf.allow_insecure,
|
||||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body),
|
body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body),
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -187,16 +187,16 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---
|
---
|
||||||
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage
|
---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage
|
||||||
---
|
---
|
||||||
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[]
|
||||||
---
|
---
|
||||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>, rawArgs: string[] | nil}
|
||||||
---@alias AvanteCurlArgsParser fun(provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput
|
||||||
---
|
---
|
||||||
---@class AvanteResponseParserOptions
|
---@class AvanteResponseParserOptions
|
||||||
---@field on_start AvanteLLMStartCallback
|
---@field on_start AvanteLLMStartCallback
|
||||||
---@field on_chunk AvanteLLMChunkCallback
|
---@field on_chunk AvanteLLMChunkCallback
|
||||||
---@field on_stop AvanteLLMStopCallback
|
---@field on_stop AvanteLLMStopCallback
|
||||||
---@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil
|
---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil
|
||||||
---
|
---
|
||||||
---@class AvanteDefaultBaseProvider: table<string, any>
|
---@class AvanteDefaultBaseProvider: table<string, any>
|
||||||
---@field endpoint? string
|
---@field endpoint? string
|
||||||
@@ -248,7 +248,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field tool_use_list? AvanteLLMToolUse[]
|
---@field tool_use_list? AvanteLLMToolUse[]
|
||||||
---@field retry_after? integer
|
---@field retry_after? integer
|
||||||
---
|
---
|
||||||
---@alias AvanteStreamParser fun(ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
|
---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil
|
||||||
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
|
---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil
|
||||||
---@alias AvanteLLMChunkCallback fun(chunk: string): any
|
---@alias AvanteLLMChunkCallback fun(chunk: string): any
|
||||||
---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil
|
---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil
|
||||||
@@ -260,10 +260,12 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_api_key? fun(): string | nil
|
---@field parse_api_key? fun(): string | nil
|
||||||
---
|
---
|
||||||
---@class AvanteProviderFunctor
|
---@class AvanteProviderFunctor
|
||||||
|
---@field support_prompt_caching boolean | nil
|
||||||
---@field role_map table<"user" | "assistant", string>
|
---@field role_map table<"user" | "assistant", string>
|
||||||
---@field parse_messages AvanteMessagesParser
|
---@field parse_messages AvanteMessagesParser
|
||||||
---@field parse_response AvanteResponseParser
|
---@field parse_response AvanteResponseParser
|
||||||
---@field parse_curl_args AvanteCurlArgsParser
|
---@field parse_curl_args AvanteCurlArgsParser
|
||||||
|
---@field is_disable_stream fun(self: AvanteProviderFunctor): boolean
|
||||||
---@field setup fun(): nil
|
---@field setup fun(): nil
|
||||||
---@field is_env_set fun(): boolean
|
---@field is_env_set fun(): boolean
|
||||||
---@field api_key_name string
|
---@field api_key_name string
|
||||||
@@ -274,11 +276,12 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---@field on_error? fun(result: table<string, any>): nil
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
---
|
---
|
||||||
|
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
||||||
|
---
|
||||||
---@class AvanteBedrockProviderFunctor: AvanteProviderFunctor
|
---@class AvanteBedrockProviderFunctor: AvanteProviderFunctor
|
||||||
---@field load_model_handler fun(): AvanteBedrockModelHandler
|
---@field load_model_handler fun(): AvanteBedrockModelHandler
|
||||||
---@field build_bedrock_payload? fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
---@field build_bedrock_payload? AvanteBedrockPayloadBuilder
|
||||||
---
|
---
|
||||||
---@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table<string, any>): table<string, any>
|
|
||||||
---
|
---
|
||||||
---@class AvanteBedrockModelHandler
|
---@class AvanteBedrockModelHandler
|
||||||
---@field role_map table<"user" | "assistant", string>
|
---@field role_map table<"user" | "assistant", string>
|
||||||
|
|||||||
Reference in New Issue
Block a user