From 6bd966e8e24808b8a1c067fd18be22d60c4b032a Mon Sep 17 00:00:00 2001 From: brook hong Date: Tue, 4 Mar 2025 00:20:27 +0800 Subject: [PATCH] fix: pass context to provider for stream data parsing (#1475) * fix: pass context to provider for stream data parsing * fix: luatype --------- Co-authored-by: yetone --- lua/avante/llm.lua | 6 ++++-- lua/avante/providers/bedrock.lua | 4 ++-- lua/avante/providers/cohere.lua | 2 +- lua/avante/types.lua | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 2594a6d..fb34e76 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -154,6 +154,8 @@ end function M._stream(opts) local provider = opts.provider or Providers[Config.provider] + ---@cast provider AvanteProviderFunctor + local prompt_opts = M.generate_prompts(opts) ---@type string @@ -285,10 +287,10 @@ function M._stream(opts) { once = true } ) end - provider.parse_stream_data(data, handler_opts) + provider.parse_stream_data(resp_ctx, data, handler_opts) else if provider.parse_stream_data ~= nil then - provider.parse_stream_data(data, handler_opts) + provider.parse_stream_data(resp_ctx, data, handler_opts) else parse_stream_data(data) end diff --git a/lua/avante/providers/bedrock.lua b/lua/avante/providers/bedrock.lua index b824dca..0906046 100644 --- a/lua/avante/providers/bedrock.lua +++ b/lua/avante/providers/bedrock.lua @@ -28,7 +28,7 @@ function M.build_bedrock_payload(prompt_opts, body_opts) return model_handler.build_bedrock_payload(prompt_opts, body_opts) end -function M.parse_stream_data(data, opts) +function M.parse_stream_data(ctx, data, opts) -- @NOTE: Decode and process Bedrock response -- Each response contains a Base64-encoded `bytes` field, which is decoded into JSON. -- The `type` field in the decoded JSON determines how the response is handled. @@ -37,7 +37,7 @@ function M.parse_stream_data(data, opts) local jsn = vim.json.decode(bedrock_data_match) local data_stream = vim.base64.decode(jsn.bytes) local json = vim.json.decode(data_stream) - M.parse_response({}, data_stream, json.type, opts) + M.parse_response(ctx, data_stream, json.type, opts) end end diff --git a/lua/avante/providers/cohere.lua b/lua/avante/providers/cohere.lua index 9869439..6e9b577 100644 --- a/lua/avante/providers/cohere.lua +++ b/lua/avante/providers/cohere.lua @@ -57,7 +57,7 @@ function M.parse_messages(opts) return { messages = messages } end -function M.parse_stream_data(data, opts) +function M.parse_stream_data(ctx, data, opts) ---@type CohereChatResponse local json = vim.json.decode(data) if json.type ~= nil then diff --git a/lua/avante/types.lua b/lua/avante/types.lua index a44eda1..0f2fc52 100644 --- a/lua/avante/types.lua +++ b/lua/avante/types.lua @@ -239,7 +239,7 @@ vim.g.avante_login = vim.g.avante_login ---@field tool_use_list? AvanteLLMToolUse[] ---@field retry_after? integer --- ----@alias AvanteStreamParser fun(line: string, handler_opts: AvanteHandlerOptions): nil +---@alias AvanteStreamParser fun(ctx: any, line: string, handler_opts: AvanteHandlerOptions): nil ---@alias AvanteLLMStartCallback fun(opts: AvanteLLMStartCallbackOptions): nil ---@alias AvanteLLMChunkCallback fun(chunk: string): any ---@alias AvanteLLMStopCallback fun(opts: AvanteLLMStopCallbackOptions): nil