refactor: history messages (#1934)
This commit is contained in:
@@ -58,6 +58,7 @@ function M:parse_stream_data(ctx, data, opts)
|
||||
end
|
||||
|
||||
function M:parse_response_without_stream(data, event_state, opts)
|
||||
if opts.on_chunk == nil then return end
|
||||
local bedrock_match = data:gmatch("exception(%b{})")
|
||||
opts.on_chunk("\n**Exception caught**\n\n")
|
||||
for bedrock_data_match in bedrock_match do
|
||||
|
||||
@@ -2,7 +2,7 @@ local Utils = require("avante.utils")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local P = require("avante.providers")
|
||||
local Config = require("avante.config")
|
||||
local StreamingJsonParser = require("avante.utils.streaming_json_parser")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -139,63 +139,6 @@ function M:parse_messages(opts)
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
if tool_history.tool_use then
|
||||
local msg = {
|
||||
role = "assistant",
|
||||
content = {},
|
||||
}
|
||||
if tool_history.tool_use.thinking_blocks then
|
||||
for _, thinking_block in ipairs(tool_history.tool_use.thinking_blocks) do
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "thinking",
|
||||
thinking = thinking_block.thinking,
|
||||
signature = thinking_block.signature,
|
||||
}
|
||||
end
|
||||
end
|
||||
if tool_history.tool_use.redacted_thinking_blocks then
|
||||
for _, redacted_thinking_block in ipairs(tool_history.tool_use.redacted_thinking_blocks) do
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "redacted_thinking",
|
||||
data = redacted_thinking_block.data,
|
||||
}
|
||||
end
|
||||
end
|
||||
if tool_history.tool_use.response_contents then
|
||||
for _, response_content in ipairs(tool_history.tool_use.response_contents) do
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "text",
|
||||
text = response_content,
|
||||
}
|
||||
end
|
||||
end
|
||||
msg.content[#msg.content + 1] = {
|
||||
type = "tool_use",
|
||||
id = tool_history.tool_use.id,
|
||||
name = tool_history.tool_use.name,
|
||||
input = vim.json.decode(tool_history.tool_use.input_json),
|
||||
}
|
||||
messages[#messages + 1] = msg
|
||||
end
|
||||
|
||||
if tool_history.tool_result then
|
||||
messages[#messages + 1] = {
|
||||
role = "user",
|
||||
content = {
|
||||
{
|
||||
type = "tool_result",
|
||||
tool_use_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.content,
|
||||
is_error = tool_history.tool_result.is_error,
|
||||
},
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
|
||||
@@ -226,14 +169,51 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
local content_block = jsn.content_block
|
||||
content_block.stoppped = false
|
||||
ctx.content_blocks[jsn.index + 1] = content_block
|
||||
if content_block.type == "thinking" then opts.on_chunk("<think>\n") end
|
||||
if content_block.type == "tool_use" and opts.on_partial_tool_use then
|
||||
opts.on_partial_tool_use({
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
partial_json = {},
|
||||
if content_block.type == "text" then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generating",
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
if content_block.type == "thinking" then
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
if opts.on_messages_add then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = content_block.thinking,
|
||||
signature = content_block.signature,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
end
|
||||
end
|
||||
if content_block.type == "tool_use" and opts.on_messages_add then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
input = {},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
})
|
||||
content_block.uuid = msg.uuid
|
||||
opts.on_messages_add({ msg })
|
||||
end
|
||||
elseif event_state == "content_block_delta" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
@@ -242,23 +222,35 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if jsn.delta.type == "input_json_delta" then
|
||||
if not content_block.input_json then content_block.input_json = "" end
|
||||
content_block.input_json = content_block.input_json .. jsn.delta.partial_json
|
||||
if opts.on_partial_tool_use then
|
||||
local streaming_json_parser = StreamingJsonParser:new()
|
||||
local partial_json = streaming_json_parser:parse(content_block.input_json)
|
||||
opts.on_partial_tool_use({
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
partial_json = partial_json or {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
return
|
||||
elseif jsn.delta.type == "thinking_delta" then
|
||||
content_block.thinking = content_block.thinking .. jsn.delta.thinking
|
||||
opts.on_chunk(jsn.delta.thinking)
|
||||
if opts.on_chunk then opts.on_chunk(jsn.delta.thinking) end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = content_block.thinking,
|
||||
signature = content_block.signature,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "text_delta" then
|
||||
content_block.text = content_block.text .. jsn.delta.text
|
||||
opts.on_chunk(jsn.delta.text)
|
||||
if opts.on_chunk then opts.on_chunk(jsn.delta.text) end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generating",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
elseif jsn.delta.type == "signature_delta" then
|
||||
if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end
|
||||
ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature
|
||||
@@ -268,12 +260,56 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if not ok then return end
|
||||
local content_block = ctx.content_blocks[jsn.index + 1]
|
||||
content_block.stoppped = true
|
||||
if content_block.type == "text" then
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = content_block.text,
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
if content_block.type == "tool_use" then
|
||||
local complete_json = vim.json.decode(content_block.input_json)
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = content_block.name,
|
||||
id = content_block.id,
|
||||
input = complete_json or {},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
if content_block.type == "thinking" then
|
||||
if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then
|
||||
opts.on_chunk("\n</think>\n\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n\n")
|
||||
if opts.on_chunk then
|
||||
if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then
|
||||
opts.on_chunk("\n</think>\n\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n\n")
|
||||
end
|
||||
end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = content_block.thinking,
|
||||
signature = content_block.signature,
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = "generated",
|
||||
uuid = content_block.uuid,
|
||||
})
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
elseif event_state == "message_delta" then
|
||||
local ok, jsn = pcall(vim.json.decode, data_stream)
|
||||
@@ -281,49 +317,20 @@ function M:parse_response(ctx, data_stream, event_state, opts)
|
||||
if jsn.delta.stop_reason == "end_turn" then
|
||||
opts.on_stop({ reason = "complete", usage = jsn.usage })
|
||||
elseif jsn.delta.stop_reason == "tool_use" then
|
||||
---@type AvanteLLMToolUse[]
|
||||
local tool_use_list = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(function(content_block) return content_block.stoppped and content_block.type == "tool_use" end)
|
||||
:map(function(content_block)
|
||||
local response_contents = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "text" end)
|
||||
:map(function(content_block_) return content_block_.text end)
|
||||
:totable()
|
||||
local thinking_blocks = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(function(content_block_) return content_block_.stoppped and content_block_.type == "thinking" end)
|
||||
:map(function(content_block_)
|
||||
---@type AvanteLLMThinkingBlock
|
||||
return { thinking = content_block_.thinking, signature = content_block_.signature }
|
||||
end)
|
||||
:totable()
|
||||
local redacted_thinking_blocks = vim
|
||||
.iter(ctx.content_blocks)
|
||||
:filter(
|
||||
function(content_block_) return content_block_.stoppped and content_block_.type == "redacted_thinking" end
|
||||
)
|
||||
:map(function(content_block_)
|
||||
---@type AvanteLLMRedactedThinkingBlock
|
||||
return { data = content_block_.data }
|
||||
end)
|
||||
:totable()
|
||||
---@type AvanteLLMToolUse
|
||||
return {
|
||||
name = content_block.name,
|
||||
local tool_use_list = {}
|
||||
for _, content_block in ipairs(ctx.content_blocks) do
|
||||
if content_block.type == "tool_use" then
|
||||
table.insert(tool_use_list, {
|
||||
id = content_block.id,
|
||||
name = content_block.name,
|
||||
input_json = content_block.input_json,
|
||||
response_contents = response_contents,
|
||||
thinking_blocks = thinking_blocks,
|
||||
redacted_thinking_blocks = redacted_thinking_blocks,
|
||||
}
|
||||
end)
|
||||
:totable()
|
||||
})
|
||||
end
|
||||
end
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
-- tool_use_list = tool_use_list,
|
||||
usage = jsn.usage,
|
||||
tool_use_list = tool_use_list,
|
||||
})
|
||||
end
|
||||
return
|
||||
@@ -351,7 +358,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
local tools = {}
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
if Config.behaviour.enable_claude_text_editor_tool_mode then
|
||||
if Config.mode == "agentic" then
|
||||
if tool.name == "create_file" then goto continue end
|
||||
if tool.name == "view" then goto continue end
|
||||
if tool.name == "str_replace" then goto continue end
|
||||
@@ -364,7 +371,7 @@ function M:parse_curl_args(prompt_opts)
|
||||
end
|
||||
end
|
||||
|
||||
if prompt_opts.tools and #prompt_opts.tools > 0 and Config.behaviour.enable_claude_text_editor_tool_mode then
|
||||
if prompt_opts.tools and #prompt_opts.tools > 0 and Config.mode == "agentic" then
|
||||
if provider_conf.model:match("claude%-3%-7%-sonnet") then
|
||||
table.insert(tools, {
|
||||
type = "text_editor_20250124",
|
||||
|
||||
@@ -211,11 +211,7 @@ M.role_map = {
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
M.parse_messages = OpenAI.parse_messages
|
||||
|
||||
M.parse_response = OpenAI.parse_response
|
||||
|
||||
M.is_reasoning_model = OpenAI.is_reasoning_model
|
||||
setmetatable(M, { __index = OpenAI })
|
||||
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
-- refresh token synchronously, only if it has expired
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local OpenAI = require("avante.providers").openai
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -10,14 +11,32 @@ M.role_map = {
|
||||
user = "user",
|
||||
assistant = "model",
|
||||
}
|
||||
-- M.tokenizer_id = "google/gemma-2b"
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
---@param tool AvanteLLMTool
|
||||
function M:transform_to_function_declaration(tool)
|
||||
local input_schema_properties, required = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)
|
||||
local parameters = nil
|
||||
if not vim.tbl_isempty(input_schema_properties) then
|
||||
parameters = {
|
||||
type = "object",
|
||||
properties = input_schema_properties,
|
||||
required = required,
|
||||
}
|
||||
end
|
||||
return {
|
||||
name = tool.name,
|
||||
description = tool.get_description and tool.get_description() or tool.description,
|
||||
parameters = parameters,
|
||||
}
|
||||
end
|
||||
|
||||
function M:parse_messages(opts)
|
||||
local contents = {}
|
||||
local prev_role = nil
|
||||
|
||||
local tool_id_to_name = {}
|
||||
vim.iter(opts.messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role then
|
||||
@@ -54,9 +73,27 @@ function M:parse_messages(opts)
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "tool_use" then
|
||||
table.insert(parts, { text = item.name })
|
||||
tool_id_to_name[item.id] = item.name
|
||||
role = "model"
|
||||
table.insert(parts, {
|
||||
functionCall = {
|
||||
name = item.name,
|
||||
args = item.input,
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "tool_result" then
|
||||
table.insert(parts, { text = item.content })
|
||||
role = "function"
|
||||
local ok, content = pcall(vim.json.decode, item.content)
|
||||
if not ok then content = item.content end
|
||||
table.insert(parts, {
|
||||
functionResponse = {
|
||||
name = tool_id_to_name[item.tool_use_id],
|
||||
response = {
|
||||
name = tool_id_to_name[item.tool_use_id],
|
||||
content = content,
|
||||
},
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "thinking" then
|
||||
table.insert(parts, { text = item.thinking })
|
||||
elseif type(item) == "table" and item.type == "redacted_thinking" then
|
||||
@@ -96,22 +133,43 @@ end
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
local ok, json = pcall(vim.json.decode, data_stream)
|
||||
if not ok then opts.on_stop({ reason = "error", error = json }) end
|
||||
if json.candidates then
|
||||
if #json.candidates > 0 then
|
||||
if json.candidates[1].finishReason and json.candidates[1].finishReason == "STOP" then
|
||||
opts.on_chunk(json.candidates[1].content.parts[1].text)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
else
|
||||
opts.on_chunk(json.candidates[1].content.parts[1].text)
|
||||
if json.candidates and #json.candidates > 0 then
|
||||
local candidate = json.candidates[1]
|
||||
---@type AvanteLLMToolUse[]
|
||||
local tool_use_list = {}
|
||||
for _, part in ipairs(candidate.content.parts) do
|
||||
if part.text then
|
||||
if opts.on_chunk then opts.on_chunk(part.text) end
|
||||
OpenAI:add_text_message(ctx, part.text, "generating", opts)
|
||||
elseif part.functionCall then
|
||||
if not ctx.function_call_id then ctx.function_call_id = 0 end
|
||||
ctx.function_call_id = ctx.function_call_id + 1
|
||||
local tool_use = {
|
||||
id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
|
||||
name = part.functionCall.name,
|
||||
input_json = vim.json.encode(part.functionCall.args),
|
||||
}
|
||||
table.insert(tool_use_list, tool_use)
|
||||
OpenAI:add_tool_use_message(tool_use, "generated", opts)
|
||||
end
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
if candidate.finishReason and candidate.finishReason == "STOP" then
|
||||
OpenAI:finish_pending_messages(ctx, opts)
|
||||
if #tool_use_list > 0 then
|
||||
opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
|
||||
else
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
end
|
||||
else
|
||||
OpenAI:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
end
|
||||
end
|
||||
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
generationConfig = {
|
||||
@@ -125,6 +183,21 @@ function M:parse_curl_args(prompt_opts)
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||
|
||||
local function_declarations = {}
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(function_declarations, self:transform_to_function_declaration(tool))
|
||||
end
|
||||
end
|
||||
|
||||
if #function_declarations > 0 then
|
||||
request_body.tools = {
|
||||
{
|
||||
functionDeclarations = function_declarations,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(
|
||||
provider_conf.endpoint,
|
||||
|
||||
@@ -215,14 +215,6 @@ function M.setup()
|
||||
E.setup({ provider = auto_suggestions_provider })
|
||||
end
|
||||
|
||||
if Config.behaviour.enable_cursor_planning_mode then
|
||||
local cursor_applying_provider_name = Config.cursor_applying_provider or Config.provider
|
||||
local cursor_applying_provider = M[cursor_applying_provider_name]
|
||||
if cursor_applying_provider and cursor_applying_provider ~= provider then
|
||||
E.setup({ provider = cursor_applying_provider })
|
||||
end
|
||||
end
|
||||
|
||||
if Config.memory_summary_provider then
|
||||
local memory_summary_provider = M[Config.memory_summary_provider]
|
||||
if memory_summary_provider and memory_summary_provider ~= provider then
|
||||
@@ -277,4 +269,13 @@ function M.get_config(provider_name)
|
||||
return type(cur) == "function" and cur() or cur
|
||||
end
|
||||
|
||||
function M.get_memory_summary_provider()
|
||||
local provider_name = Config.memory_summary_provider
|
||||
if provider_name == nil then
|
||||
if M.openai.is_env_set() then provider_name = "openai-gpt-4o-mini" end
|
||||
end
|
||||
if provider_name == nil then provider_name = Config.provider end
|
||||
return M[provider_name]
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -16,7 +16,7 @@ M.is_reasoning_model = P.openai.is_reasoning_model
|
||||
|
||||
function M:is_disable_stream() return false end
|
||||
|
||||
function M:parse_stream_data(ctx, data, handler_opts)
|
||||
function M:parse_stream_data(ctx, data, opts)
|
||||
local ok, json_data = pcall(vim.json.decode, data)
|
||||
if not ok or not json_data then
|
||||
-- Add debug logging
|
||||
@@ -26,11 +26,13 @@ function M:parse_stream_data(ctx, data, handler_opts)
|
||||
|
||||
if json_data.message and json_data.message.content then
|
||||
local content = json_data.message.content
|
||||
if content and content ~= "" then handler_opts.on_chunk(content) end
|
||||
P.openai:add_text_message(ctx, content, "generating", opts)
|
||||
if content and content ~= "" and opts.on_chunk then opts.on_chunk(content) end
|
||||
end
|
||||
|
||||
if json_data.done then
|
||||
handler_opts.on_stop({ reason = "complete" })
|
||||
P.openai:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
@@ -2,7 +2,7 @@ local Utils = require("avante.utils")
|
||||
local Config = require("avante.config")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local Providers = require("avante.providers")
|
||||
local StreamingJsonParser = require("avante.utils.streaming_json_parser")
|
||||
local HistoryMessage = require("avante.history_message")
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -164,117 +164,154 @@ function M:parse_messages(opts)
|
||||
table.insert(final_messages, message)
|
||||
end)
|
||||
|
||||
if opts.tool_histories then
|
||||
for _, tool_history in ipairs(opts.tool_histories) do
|
||||
table.insert(final_messages, {
|
||||
role = self.role_map["assistant"],
|
||||
tool_calls = {
|
||||
{
|
||||
id = tool_history.tool_use.id,
|
||||
type = "function",
|
||||
["function"] = {
|
||||
name = tool_history.tool_use.name,
|
||||
arguments = tool_history.tool_use.input_json,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
local result_content = tool_history.tool_result.content or ""
|
||||
table.insert(final_messages, {
|
||||
role = "tool",
|
||||
tool_call_id = tool_history.tool_result.tool_use_id,
|
||||
content = tool_history.tool_result.is_error and "Error: " .. result_content or result_content,
|
||||
})
|
||||
return final_messages
|
||||
end
|
||||
|
||||
function M:finish_pending_messages(ctx, opts)
|
||||
if ctx.content ~= nil and ctx.content ~= "" then self:add_text_message(ctx, "", "generated", opts) end
|
||||
if ctx.tool_use_list then
|
||||
for _, tool_use in ipairs(ctx.tool_use_list) do
|
||||
if tool_use.state == "generating" then self:add_tool_use_message(tool_use, "generated", opts) end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return final_messages
|
||||
function M:add_text_message(ctx, text, state, opts)
|
||||
if ctx.content == nil then ctx.content = "" end
|
||||
ctx.content = ctx.content .. text
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = ctx.content,
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.content_uuid,
|
||||
})
|
||||
ctx.content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:add_thinking_message(ctx, text, state, opts)
|
||||
if ctx.reasonging_content == nil then ctx.reasonging_content = "" end
|
||||
ctx.reasonging_content = ctx.reasonging_content .. text
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "thinking",
|
||||
thinking = ctx.reasonging_content,
|
||||
signature = "",
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = ctx.reasonging_content_uuid,
|
||||
})
|
||||
ctx.reasonging_content_uuid = msg.uuid
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:add_tool_use_message(tool_use, state, opts)
|
||||
local jsn = nil
|
||||
if state == "generated" then jsn = vim.json.decode(tool_use.input_json) end
|
||||
local msg = HistoryMessage:new({
|
||||
role = "assistant",
|
||||
content = {
|
||||
{
|
||||
type = "tool_use",
|
||||
name = tool_use.name,
|
||||
id = tool_use.id,
|
||||
input = jsn or {},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
state = state,
|
||||
uuid = tool_use.uuid,
|
||||
})
|
||||
tool_use.uuid = msg.uuid
|
||||
tool_use.state = state
|
||||
if opts.on_messages_add then opts.on_messages_add({ msg }) end
|
||||
end
|
||||
|
||||
function M:parse_response(ctx, data_stream, _, opts)
|
||||
if data_stream:match('"%[DONE%]":') then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
end
|
||||
if data_stream:match('"delta":') then
|
||||
---@type AvanteOpenAIChatResponse
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
if jsn.choices and jsn.choices[1] then
|
||||
local choice = jsn.choices[1]
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
if choice.delta.content and choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end
|
||||
opts.on_stop({ reason = "complete" })
|
||||
elseif choice.finish_reason == "tool_calls" then
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
usage = jsn.usage,
|
||||
tool_use_list = ctx.tool_use_list,
|
||||
})
|
||||
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
opts.on_chunk("<think>\n")
|
||||
if not data_stream:match('"delta":') then return end
|
||||
---@type AvanteOpenAIChatResponse
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
if not jsn.choices or not jsn.choices[1] then return end
|
||||
local choice = jsn.choices[1]
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
if choice.delta.content and choice.delta.content ~= vim.NIL then
|
||||
self:add_text_message(ctx, choice.delta.content, "generated", opts)
|
||||
opts.on_chunk(choice.delta.content)
|
||||
end
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
elseif choice.finish_reason == "tool_calls" then
|
||||
self:finish_pending_messages(ctx, opts)
|
||||
opts.on_stop({
|
||||
reason = "tool_use",
|
||||
-- tool_use_list = ctx.tool_use_list,
|
||||
usage = jsn.usage,
|
||||
})
|
||||
elseif choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning_content
|
||||
self:add_thinking_message(ctx, choice.delta.reasoning_content, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning_content) end
|
||||
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning
|
||||
self:add_thinking_message(ctx, choice.delta.reasoning, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning) end
|
||||
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
|
||||
for _, tool_call in ipairs(choice.delta.tool_calls) do
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
if tool_call.index > 0 and ctx.tool_use_list[tool_call.index] then
|
||||
local prev_tool_use = ctx.tool_use_list[tool_call.index]
|
||||
self:add_tool_use_message(prev_tool_use, "generated", opts)
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning_content
|
||||
opts.on_chunk(choice.delta.reasoning_content)
|
||||
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
opts.on_chunk("<think>\n")
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning
|
||||
opts.on_chunk(choice.delta.reasoning)
|
||||
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
|
||||
for _, tool_call in ipairs(choice.delta.tool_calls) do
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
if not ctx.tool_use_list[tool_call.index + 1] then
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
input_json = "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
if opts.on_partial_tool_use then
|
||||
opts.on_partial_tool_use({
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
partial_json = {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
if opts.on_partial_tool_use then
|
||||
local parser = StreamingJsonParser:new()
|
||||
local partial_json = parser:parse(tool_use.input_json)
|
||||
opts.on_partial_tool_use({
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
partial_json = partial_json or {},
|
||||
state = "generating",
|
||||
})
|
||||
end
|
||||
end
|
||||
end
|
||||
elseif choice.delta.content then
|
||||
if
|
||||
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
|
||||
then
|
||||
ctx.returned_think_end_tag = true
|
||||
if
|
||||
ctx.last_think_content
|
||||
and ctx.last_think_content ~= vim.NIL
|
||||
and ctx.last_think_content:sub(-1) ~= "\n"
|
||||
then
|
||||
opts.on_chunk("\n</think>\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n")
|
||||
end
|
||||
end
|
||||
if choice.delta.content ~= vim.NIL then opts.on_chunk(choice.delta.content) end
|
||||
local tool_use = {
|
||||
name = tool_call["function"].name,
|
||||
id = tool_call.id,
|
||||
input_json = "",
|
||||
}
|
||||
ctx.tool_use_list[tool_call.index + 1] = tool_use
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
else
|
||||
local tool_use = ctx.tool_use_list[tool_call.index + 1]
|
||||
tool_use.input_json = tool_use.input_json .. tool_call["function"].arguments
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
end
|
||||
end
|
||||
elseif choice.delta.content then
|
||||
if
|
||||
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
|
||||
then
|
||||
ctx.returned_think_end_tag = true
|
||||
if opts.on_chunk then
|
||||
if ctx.last_think_content and ctx.last_think_content ~= vim.NIL and ctx.last_think_content:sub(-1) ~= "\n" then
|
||||
opts.on_chunk("\n</think>\n")
|
||||
else
|
||||
opts.on_chunk("</think>\n")
|
||||
end
|
||||
end
|
||||
self:add_thinking_message(ctx, "", "generated", opts)
|
||||
end
|
||||
if choice.delta.content ~= vim.NIL then
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.content) end
|
||||
self:add_text_message(ctx, choice.delta.content, "generating", opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user