fix(claude): sending state manually (#84)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
@@ -204,10 +204,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
||||
---@alias AvanteCurlArgsBuilder fun(code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---
|
||||
---@class ResponseParser
|
||||
---@field event_state string
|
||||
---@field on_chunk fun(chunk: string): any
|
||||
---@field on_complete fun(err: string|nil): any
|
||||
---@alias AvanteResponseParser fun(data_stream: string, opts: ResponseParser): nil
|
||||
---@alias AvanteResponseParser fun(data_stream: string, event_state: string, opts: ResponseParser): nil
|
||||
---
|
||||
---@class AvanteProvider
|
||||
---@field endpoint string
|
||||
@@ -215,6 +214,9 @@ Remember: Accurate line numbers are CRITICAL. The range start_line to end_line m
|
||||
---@field api_key_name string
|
||||
---@field parse_response_data AvanteResponseParser
|
||||
---@field parse_curl_args fun(opts: AvanteProvider, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
---
|
||||
---@alias AvanteChunkParser fun(chunk: string): any
|
||||
---@alias AvanteCompleteParser fun(err: string|nil): nil
|
||||
|
||||
------------------------------Anthropic------------------------------
|
||||
|
||||
@@ -278,14 +280,14 @@ H.make_claude_message = function(opts)
|
||||
end
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_claude_response = function(data_stream, opts)
|
||||
if opts.event_state == "content_block_delta" then
|
||||
H.parse_claude_response = function(data_stream, event_state, opts)
|
||||
if event_state == "content_block_delta" then
|
||||
local json = vim.json.decode(data_stream)
|
||||
opts.on_chunk(json.delta.text)
|
||||
elseif opts.event_state == "message_stop" then
|
||||
elseif event_state == "message_stop" then
|
||||
opts.on_complete(nil)
|
||||
return
|
||||
elseif opts.event_state == "error" then
|
||||
elseif event_state == "error" then
|
||||
opts.on_complete(vim.json.decode(data_stream))
|
||||
end
|
||||
end
|
||||
@@ -351,7 +353,7 @@ H.make_openai_message = function(opts)
|
||||
end
|
||||
|
||||
---@type AvanteResponseParser
|
||||
H.parse_openai_response = function(data_stream, opts)
|
||||
H.parse_openai_response = function(data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
opts.on_complete(nil)
|
||||
return
|
||||
@@ -477,8 +479,8 @@ local active_job = nil
|
||||
---@param code_lang string
|
||||
---@param code_content string
|
||||
---@param selected_content_content string | nil
|
||||
---@param on_chunk fun(chunk: string): any
|
||||
---@param on_complete fun(err: string|nil): any
|
||||
---@param on_chunk AvanteChunkParser
|
||||
---@param on_complete AvanteCompleteParser
|
||||
M.stream = function(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
|
||||
local provider = Config.provider
|
||||
|
||||
@@ -488,7 +490,8 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
||||
code_content = code_content,
|
||||
selected_code_content = selected_content_content,
|
||||
}
|
||||
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete, event_state = nil }
|
||||
local current_event_state = nil
|
||||
local handler_opts = { on_chunk = on_chunk, on_complete = on_complete }
|
||||
|
||||
---@type AvanteCurlOutput
|
||||
local spec = nil
|
||||
@@ -502,6 +505,7 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
||||
ProviderConfig = Config.vendors[provider]
|
||||
spec = ProviderConfig.parse_curl_args(ProviderConfig, code_opts)
|
||||
end
|
||||
--- If the provider doesn't have stream set, we set it to true
|
||||
if spec.body.stream == nil then
|
||||
spec = vim.tbl_deep_extend("force", spec, {
|
||||
body = { stream = true },
|
||||
@@ -512,15 +516,15 @@ M.stream = function(question, code_lang, code_content, selected_content_content,
|
||||
local function parse_and_call(line)
|
||||
local event = line:match("^event: (.+)$")
|
||||
if event then
|
||||
handler_opts.event_state = event
|
||||
current_event_state = event
|
||||
return
|
||||
end
|
||||
local data_match = line:match("^data: (.+)$")
|
||||
if data_match then
|
||||
if ProviderConfig ~= nil then
|
||||
ProviderConfig.parse_response_data(data_match, handler_opts)
|
||||
ProviderConfig.parse_response_data(data_match, current_event_state, handler_opts)
|
||||
else
|
||||
H["parse_" .. provider .. "_response"](data_match, handler_opts)
|
||||
H["parse_" .. provider .. "_response"](data_match, current_event_state, handler_opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user