diff --git a/lua/avante/config.lua b/lua/avante/config.lua index 31031c4..eb59b93 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -544,6 +544,7 @@ M.BASE_PROVIDER_KEYS = { "tokenizer_id", "use_xml_format", "role_map", + "support_prompt_caching", "__inherited_from", "disable_tools", "entra", diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 662531d..49ce67f 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -232,7 +232,7 @@ function M._stream(opts) } ---@type AvanteCurlOutput - local spec = provider.parse_curl_args(provider, prompt_opts) + local spec = provider:parse_curl_args(provider, prompt_opts) local resp_ctx = {} @@ -244,11 +244,11 @@ function M._stream(opts) return end local data_match = line:match("^data: (.+)$") - if data_match then provider.parse_response(resp_ctx, data_match, current_event_state, handler_opts) end + if data_match then provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts) end end local function parse_response_without_stream(data) - provider.parse_response_without_stream(data, current_event_state, handler_opts) + provider:parse_response_without_stream(data, current_event_state, handler_opts) end local completed = false @@ -287,10 +287,10 @@ function M._stream(opts) { once = true } ) end - provider.parse_stream_data(resp_ctx, data, handler_opts) + provider:parse_stream_data(resp_ctx, data, handler_opts) else if provider.parse_stream_data ~= nil then - provider.parse_stream_data(resp_ctx, data, handler_opts) + provider:parse_stream_data(resp_ctx, data, handler_opts) else parse_stream_data(data) end @@ -357,7 +357,7 @@ function M._stream(opts) end -- If stream is not enabled, then handle the response here - if (spec.body.stream == nil or spec.body.stream == false) and result.status == 200 then + if provider:is_disable_stream() and result.status == 200 then vim.schedule(function() completed = true parse_response_without_stream(result.body) diff --git a/lua/avante/providers/azure.lua b/lua/avante/providers/azure.lua index c2842b5..04a7b14 100644 --- a/lua/avante/providers/azure.lua +++ b/lua/avante/providers/azure.lua @@ -16,8 +16,9 @@ M.api_key_name = "AZURE_OPENAI_API_KEY" M.parse_messages = O.parse_messages M.parse_response = O.parse_response M.parse_response_without_stream = O.parse_response_without_stream +M.is_disable_stream = O.is_disable_stream -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local provider_conf, request_body = P.parse_config(provider) local headers = { @@ -52,7 +53,7 @@ function M.parse_curl_args(provider, prompt_opts) insecure = provider_conf.allow_insecure, headers = headers, body = vim.tbl_deep_extend("force", { - messages = M.parse_messages(prompt_opts), + messages = self:parse_messages(prompt_opts), stream = true, }, request_body), } diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index bbc1bc4..3bfa4ae 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -18,17 +18,17 @@ function M.load_model_handler() error(error_msg) end -function M.parse_response(ctx, data_stream, event_state, opts) +function M:parse_response(ctx, data_stream, event_state, opts) local model_handler = M.load_model_handler() - return model_handler.parse_response(ctx, data_stream, event_state, opts) + return model_handler.parse_response(self, ctx, data_stream, event_state, opts) end -function M.build_bedrock_payload(prompt_opts, body_opts) +function M:build_bedrock_payload(prompt_opts, body_opts) local model_handler = M.load_model_handler() - return model_handler.build_bedrock_payload(prompt_opts, body_opts) + return model_handler.build_bedrock_payload(self, prompt_opts, body_opts) end -function M.parse_stream_data(ctx, data, opts) +function M:parse_stream_data(ctx, data, opts) -- @NOTE: Decode and process Bedrock response -- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON. -- The `type` field in the decoded JSON determines how the response is handled. @@ -37,11 +37,11 @@ function M.parse_stream_data(ctx, data, opts) local jsn = vim.json.decode(bedrock_data_match) local data_stream = vim.base64.decode(jsn.bytes) local json = vim.json.decode(data_stream) - M.parse_response(ctx, data_stream, json.type, opts) + self:parse_response(ctx, data_stream, json.type, opts) end end -function M.parse_response_without_stream(data, event_state, opts) +function M:parse_response_without_stream(data, event_state, opts) local bedrock_match = data:gmatch("exception(%b{})") opts.on_chunk("\n**Exception caught**\n\n") for bedrock_data_match in bedrock_match do @@ -54,7 +54,7 @@ end ---@param provider AvanteBedrockProviderFunctor ---@param prompt_opts AvantePromptOptions ---@return table -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local base, body_opts = P.parse_config(provider) local api_key = provider.parse_api_key() @@ -77,7 +77,7 @@ function M.parse_curl_args(provider, prompt_opts) if aws_session_token and aws_session_token ~= "" then headers["x-amz-security-token"] = aws_session_token end - local body_payload = M.build_bedrock_payload(prompt_opts, body_opts) + local body_payload = self:build_bedrock_payload(prompt_opts, body_opts) local rawArgs = { "--aws-sigv4", diff --git a/lua/avante/providers/bedrock/claude.lua b/lua/avante/providers/bedrock/claude.lua index d9bf2ee..a714821 100644 --- a/lua/avante/providers/bedrock/claude.lua +++ b/lua/avante/providers/bedrock/claude.lua @@ -11,20 +11,23 @@ local Claude = require("avante.providers.claude") ---@class AvanteBedrockModelHandler local M = {} +M.support_prompt_caching = false M.role_map = { user = "user", assistant = "assistant", } +M.is_disable_stream = Claude.is_disable_stream M.parse_messages = Claude.parse_messages M.parse_response = Claude.parse_response +---@param provider AvanteProviderFunctor ---@param prompt_opts AvantePromptOptions ---@param body_opts table ---@return table -function M.build_bedrock_payload(prompt_opts, body_opts) +function M.build_bedrock_payload(provider, prompt_opts, body_opts) local system_prompt = prompt_opts.system_prompt or "" - local messages = M.parse_messages(prompt_opts) + local messages = provider:parse_messages(prompt_opts) local max_tokens = body_opts.max_tokens or 2000 local payload = { anthropic_version = "bedrock-2023-05-31", diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index 1ba2135..dcfae6d 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -30,13 +30,16 @@ local M = {} M.api_key_name = "ANTHROPIC_API_KEY" M.use_xml_format = true +M.support_prompt_caching = true M.role_map = { user = "user", assistant = "assistant", } -function M.parse_messages(opts) +function M:is_disable_stream() return false end + +function M:parse_messages(opts) ---@type AvanteClaudeMessage[] local messages = {} @@ -50,13 +53,15 @@ function M.parse_messages(opts) ---@type table local top_two = {} - for i = 1, math.min(2, #messages_with_length) do - top_two[messages_with_length[i].idx] = true + if self.support_prompt_caching then + for i = 1, math.min(2, #messages_with_length) do + top_two[messages_with_length[i].idx] = true + end end for idx, message in ipairs(opts.messages) do table.insert(messages, { - role = M.role_map[message.role], + role = self.role_map[message.role], content = { { type = "text", @@ -142,7 +147,7 @@ function M.parse_messages(opts) return messages end -function M.parse_response(ctx, data_stream, event_state, opts) +function M:parse_response(ctx, data_stream, event_state, opts) if event_state == nil then if data_stream:match('"message_start"') then event_state = "message_start" @@ -260,7 +265,7 @@ end ---@param provider AvanteProviderFunctor ---@param prompt_opts AvantePromptOptions ---@return table -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local provider_conf, request_body = P.parse_config(provider) local disable_tools = provider_conf.disable_tools or false @@ -272,7 +277,7 @@ function M.parse_curl_args(provider, prompt_opts) if P.env.require_api_key(provider_conf) then headers["x-api-key"] = provider.parse_api_key() end - local messages = M.parse_messages(prompt_opts) + local messages = self:parse_messages(prompt_opts) local tools = {} if not disable_tools and prompt_opts.tools then @@ -281,7 +286,7 @@ function M.parse_curl_args(provider, prompt_opts) end end - if #tools > 0 then + if self.support_prompt_caching and #tools > 0 then local last_tool = vim.deepcopy(tools[#tools]) last_tool.cache_control = { type = "ephemeral" } tools[#tools] = last_tool diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index 6e9b577..8725a13 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -47,7 +47,9 @@ M.role_map = { assistant = "assistant", } -function M.parse_messages(opts) +function M:is_disable_stream() return false end + +function M:parse_messages(opts) local messages = { { role = "system", content = opts.system_prompt }, } @@ -57,7 +59,7 @@ function M.parse_messages(opts) return { messages = messages } end -function M.parse_stream_data(ctx, data, opts) +function M:parse_stream_data(ctx, data, opts) ---@type CohereChatResponse local json = vim.json.decode(data) if json.type ~= nil then @@ -69,7 +71,7 @@ function M.parse_stream_data(ctx, data, opts) end end -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local provider_conf, request_body = P.parse_config(provider) local headers = { @@ -92,7 +94,7 @@ function M.parse_curl_args(provider, prompt_opts) body = vim.tbl_deep_extend("force", { model = provider_conf.model, stream = true, - }, M.parse_messages(prompt_opts), request_body), + }, self:parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/providers/copilot.lua b/lua/avante/providers/copilot.lua index b792723..ddd22e3 100644 --- a/lua/avante/providers/copilot.lua +++ b/lua/avante/providers/copilot.lua @@ -209,11 +209,13 @@ M.role_map = { assistant = "assistant", } +function M:is_disable_stream() return false end + M.parse_messages = OpenAI.parse_messages M.parse_response = OpenAI.parse_response -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) -- refresh token synchronously, only if it has expired -- (this should rarely happen, as we refresh the token in the background) H.refresh_token(false, false) @@ -241,7 +243,7 @@ function M.parse_curl_args(provider, prompt_opts) }, body = vim.tbl_deep_extend("force", { model = provider_conf.model, - messages = M.parse_messages(prompt_opts), + messages = self:parse_messages(prompt_opts), stream = true, tools = tools, }, request_body), diff --git a/lua/avante/providers/gemini.lua b/lua/avante/providers/gemini.lua index a891e4e..14d15da 100644 --- a/lua/avante/providers/gemini.lua +++ b/lua/avante/providers/gemini.lua @@ -12,7 +12,9 @@ M.role_map = { } -- M.tokenizer_id = "google/gemma-2b" -function M.parse_messages(opts) +function M:is_disable_stream() return false end + +function M:parse_messages(opts) local contents = {} local prev_role = nil @@ -64,7 +66,7 @@ function M.parse_messages(opts) } end -function M.parse_response(ctx, data_stream, _, opts) +function M:parse_response(ctx, data_stream, _, opts) local ok, json = pcall(vim.json.decode, data_stream) if not ok then opts.on_stop({ reason = "error", error = json }) end if json.candidates then @@ -81,7 +83,7 @@ function M.parse_response(ctx, data_stream, _, opts) end end -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local provider_conf, request_body = P.parse_config(provider) request_body = vim.tbl_deep_extend("force", request_body, { @@ -104,7 +106,7 @@ function M.parse_curl_args(provider, prompt_opts) proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, headers = { ["Content-Type"] = "application/json" }, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body), + body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 9eff43a..36f53d0 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -13,6 +13,8 @@ M.role_map = { assistant = "assistant", } +function M:is_disable_stream() return false end + ---@param tool AvanteLLMTool ---@return AvanteOpenAITool function M.transform_tool(tool) @@ -63,14 +65,14 @@ end function M.is_o_series_model(model) return model and string.match(model, "^o%d+") ~= nil end -function M.parse_messages(opts) +function M:parse_messages(opts) local messages = {} local provider = P[Config.provider] local base, _ = P.parse_config(provider) -- NOTE: Handle the case where the selected model is the `o1` model -- "o1" models are "smart" enough to understand user prompt as a system prompt in this context - if M.is_o_series_model(base.model) then + if self.is_o_series_model(base.model) then table.insert(messages, { role = "user", content = opts.system_prompt }) else table.insert(messages, { role = "system", content = opts.system_prompt }) @@ -100,20 +102,20 @@ function M.parse_messages(opts) vim.iter(messages):each(function(message) local role = message.role if role == prev_role then - if role == M.role_map["user"] then - table.insert(final_messages, { role = M.role_map["assistant"], content = "Ok, I understand." }) + if role == self.role_map["user"] then + table.insert(final_messages, { role = self.role_map["assistant"], content = "Ok, I understand." }) else - table.insert(final_messages, { role = M.role_map["user"], content = "Ok" }) + table.insert(final_messages, { role = self.role_map["user"], content = "Ok" }) end end prev_role = role - table.insert(final_messages, { role = M.role_map[role] or role, content = message.content }) + table.insert(final_messages, { role = self.role_map[role] or role, content = message.content }) end) if opts.tool_histories then for _, tool_history in ipairs(opts.tool_histories) do table.insert(final_messages, { - role = M.role_map["assistant"], + role = self.role_map["assistant"], tool_calls = { { id = tool_history.tool_use.id, @@ -137,7 +139,7 @@ function M.parse_messages(opts) return final_messages end -function M.parse_response(ctx, data_stream, _, opts) +function M:parse_response(ctx, data_stream, _, opts) if data_stream:match('"%[DONE%]":') then opts.on_stop({ reason = "complete" }) return @@ -205,7 +207,7 @@ function M.parse_response(ctx, data_stream, _, opts) end end -function M.parse_response_without_stream(data, _, opts) +function M:parse_response_without_stream(data, _, opts) ---@type AvanteOpenAIChatResponse local json = vim.json.decode(data) if json.choices and json.choices[1] then @@ -217,7 +219,7 @@ function M.parse_response_without_stream(data, _, opts) end end -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local provider_conf, request_body = P.parse_config(provider) local disable_tools = provider_conf.disable_tools or false @@ -240,8 +242,7 @@ function M.parse_curl_args(provider, prompt_opts) end -- NOTE: When using "o" series set the supported parameters only - local stream = true - if M.is_o_series_model(provider_conf.model) then + if self.is_o_series_model(provider_conf.model) then request_body.max_completion_tokens = request_body.max_tokens request_body.max_tokens = nil request_body.temperature = 1 @@ -251,7 +252,7 @@ function M.parse_curl_args(provider, prompt_opts) if not disable_tools and prompt_opts.tools then tools = {} for _, tool in ipairs(prompt_opts.tools) do - table.insert(tools, M.transform_tool(tool)) + table.insert(tools, self.transform_tool(tool)) end end @@ -265,8 +266,8 @@ function M.parse_curl_args(provider, prompt_opts) headers = headers, body = vim.tbl_deep_extend("force", { model = provider_conf.model, - messages = M.parse_messages(prompt_opts), - stream = stream, + messages = self:parse_messages(prompt_opts), + stream = true, tools = tools, }, request_body), } diff --git a/lua/avante/providers/vertex.lua b/lua/avante/providers/vertex.lua index 9cc19cd..e1fb61e 100644 --- a/lua/avante/providers/vertex.lua +++ b/lua/avante/providers/vertex.lua @@ -11,6 +11,7 @@ M.role_map = { assistant = "model", } +M.is_disable_stream = Gemini.is_disable_stream M.parse_messages = Gemini.parse_messages M.parse_response = Gemini.parse_response @@ -31,11 +32,13 @@ function M.parse_api_key() return direct_output end -function M.parse_curl_args(provider, prompt_opts) +function M:parse_curl_args(provider, prompt_opts) local provider_conf, request_body = P.parse_config(provider) - local location = vim.fn.getenv("LOCATION") or "default-location" - local project_id = vim.fn.getenv("PROJECT_ID") or "default-project-id" + local location = vim.fn.getenv("LOCATION") + local project_id = vim.fn.getenv("PROJECT_ID") local model_id = provider_conf.model or "default-model-id" + if location == nil or location == vim.NIL then location = "default-location" end + if project_id == nil or project_id == vim.NIL then project_id = "default-project-id" end local url = provider_conf.endpoint:gsub("LOCATION", location):gsub("PROJECT_ID", project_id) url = string.format("%s/%s:streamGenerateContent?alt=sse", url, model_id) @@ -58,7 +61,7 @@ function M.parse_curl_args(provider, prompt_opts) }, proxy = provider_conf.proxy, insecure = provider_conf.allow_insecure, - body = vim.tbl_deep_extend("force", {}, M.parse_messages(prompt_opts), request_body), + body = vim.tbl_deep_extend("force", {}, self:parse_messages(prompt_opts), request_body), } end diff --git a/lua/avante/types.lua b/lua/avante/types.lua index 646e572..78591ba 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -187,16 +187,16 @@ vim.g.avante_login = vim.g.avante_login --- ---@alias AvanteChatMessage AvanteClaudeMessage | AvanteOpenAIMessage | AvanteGeminiMessage --- ----@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[] +---@alias AvanteMessagesParser fun(self: AvanteProviderFunctor, opts: AvantePromptOptions): AvanteChatMessage[] --- ---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table | string, headers: table, rawArgs: string[] | nil} ----@alias AvanteCurlArgsParser fun(provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput +---@alias AvanteCurlArgsParser fun(self: AvanteProviderFunctor, provider: AvanteProvider | AvanteProviderFunctor | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions): AvanteCurlOutput --- ---@class AvanteResponseParserOptions ---@field on_start AvanteLLMStartCallback ---@field on_chunk AvanteLLMChunkCallback ---@field on_stop AvanteLLMStopCallback ----@alias AvanteResponseParser fun(ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil +---@alias AvanteResponseParser fun(self: AvanteProviderFunctor, ctx: any, data_stream: string, event_state: string, opts: AvanteResponseParserOptions): nil --- ---@class AvanteDefaultBaseProvider: table ---@field endpoint? string @@ -248,7 +248,7 @@ vim.g.avante_login = vim.g.avante_login ---@field tool_use_list? AvanteLLMToolUse[] ---@field retry_after? integer --- ----@alias AvanteStreamParser fun(ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil +---@alias AvanteStreamParser fun(self: AvanteProviderFunctor, ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil ---@alias AvanteLLMChunkCallback fun(chunk: string): any ---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil @@ -260,10 +260,12 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_api_key? fun(): string | nil --- ---@class AvanteProviderFunctor +---@field support_prompt_caching boolean | nil ---@field role_map table<"user" | "assistant", string> ---@field parse_messages AvanteMessagesParser ---@field parse_response AvanteResponseParser ---@field parse_curl_args AvanteCurlArgsParser +---@field is_disable_stream fun(self: AvanteProviderFunctor): boolean ---@field setup fun(): nil ---@field is_env_set fun(): boolean ---@field api_key_name string @@ -274,11 +276,12 @@ vim.g.avante_login = vim.g.avante_login ---@field parse_stream_data? AvanteStreamParser ---@field on_error? fun(result: table): nil --- +---@alias AvanteBedrockPayloadBuilder fun(self: AvanteProviderFunctor, prompt_opts: AvantePromptOptions, body_opts: table): table +--- ---@class AvanteBedrockProviderFunctor: AvanteProviderFunctor ---@field load_model_handler fun(): AvanteBedrockModelHandler ----@field build_bedrock_payload? fun(prompt_opts: AvantePromptOptions, body_opts: table): table +---@field build_bedrock_payload? AvanteBedrockPayloadBuilder --- ----@alias AvanteBedrockPayloadBuilder fun(prompt_opts: AvantePromptOptions, body_opts: table): table --- ---@class AvanteBedrockModelHandler ---@field role_map table<"user" | "assistant", string>