diff --git a/lua/avante/history/init.lua b/lua/avante/history/init.lua index 524e10b..2089934 100644 --- a/lua/avante/history/init.lua +++ b/lua/avante/history/init.lua @@ -14,10 +14,7 @@ function M.get_history_messages(history) local messages = {} for _, entry in ipairs(history.entries or {}) do if entry.request and entry.request ~= "" then - local message = Message:new({ - role = "user", - content = entry.request, - }, { + local message = Message:new("user", entry.request, { timestamp = entry.timestamp, is_user_submission = true, visible = entry.visible, @@ -27,10 +24,7 @@ function M.get_history_messages(history) table.insert(messages, message) end if entry.response and entry.response ~= "" then - local message = Message:new({ - role = "assistant", - content = entry.response, - }, { + local message = Message:new("assistant", entry.response, { timestamp = entry.timestamp, visible = entry.visible, }) diff --git a/lua/avante/history/message.lua b/lua/avante/history/message.lua index ffc91b2..eea5c79 100644 --- a/lua/avante/history/message.lua +++ b/lua/avante/history/message.lua @@ -18,10 +18,13 @@ M.__index = M ---@field just_for_display? boolean ---@field visible? boolean --- ----@param message AvanteLLMMessage +---@param role "user" | "assistant" +---@param content AvanteLLMMessageContentItem ---@param opts? avante.HistoryMessage.Opts ---@return avante.HistoryMessage -function M:new(message, opts) +function M:new(role, content, opts) + ---@type AvanteLLMMessage + local message = { role = role, content = type(content) == "string" and content or { content } } local obj = { message = message, uuid = Utils.uuid(), @@ -36,20 +39,17 @@ end ---Creates a new instance of synthetic (dummy) history message ---@param role "assistant" | "user" ----@param item AvanteLLMMessageContentItem | string +---@param item AvanteLLMMessageContentItem ---@return avante.HistoryMessage -function M:new_synthetic(role, item) - local content = type(item) == "string" and item or { item } - return M:new({ role = role, content = content }, { is_dummy = true }) -end +function M:new_synthetic(role, item) return M:new(role, item, { is_dummy = true }) end ---Creates a new instance of synthetic (dummy) history message attributed to the assistant ----@param item AvanteLLMMessageContentItem | string +---@param item AvanteLLMMessageContentItem ---@return avante.HistoryMessage function M:new_assistant_synthetic(item) return M:new_synthetic("assistant", item) end ---Creates a new instance of synthetic (dummy) history message attributed to the user ----@param item AvanteLLMMessageContentItem | string +---@param item AvanteLLMMessageContentItem ---@return avante.HistoryMessage function M:new_user_synthetic(item) return M:new_synthetic("user", item) end diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 5eb3489..302321b 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -779,17 +779,12 @@ function M._stream(opts) ---@type avante.HistoryMessage[] local messages = {} for _, tool_result in ipairs(tool_results) do - messages[#messages + 1] = History.Message:new({ - role = "user", - content = { - { - type = "tool_result", - tool_use_id = tool_result.tool_use_id, - content = tool_result.content, - is_error = tool_result.is_error, - is_user_declined = tool_result.is_user_declined, - }, - }, + messages[#messages + 1] = History.Message:new("user", { + type = "tool_result", + tool_use_id = tool_result.tool_use_id, + content = tool_result.content, + is_error = tool_result.is_error, + is_user_declined = tool_result.is_user_declined, }) end if opts.on_messages_add then opts.on_messages_add(messages) end @@ -822,12 +817,10 @@ function M._stream(opts) -- Special handling for cancellation signal from tools if error == LLMToolHelpers.CANCEL_TOKEN then Utils.debug("Tool execution was cancelled by user") - if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user during tool execution.]*\n") end + local cancelled_text = "\n*[Request cancelled by user during tool execution.]*\n" + if opts.on_chunk then opts.on_chunk(cancelled_text) end if opts.on_messages_add then - local message = History.Message:new({ - role = "assistant", - content = "\n*[Request cancelled by user during tool execution.]*\n", - }, { + local message = History.Message:new("assistant", cancelled_text, { just_for_display = true, }) opts.on_messages_add({ message }) @@ -878,12 +871,10 @@ function M._stream(opts) if result ~= nil or error ~= nil then return handle_tool_result(result, error) end end if stop_opts.reason == "cancelled" then - if opts.on_chunk then opts.on_chunk("\n*[Request cancelled by user.]*\n") end + local cancelled_text = "\n*[Request cancelled by user.]*\n" + if opts.on_chunk then opts.on_chunk(cancelled_text) end if opts.on_messages_add then - local message = History.Message:new({ - role = "assistant", - content = "\n*[Request cancelled by user.]*\n", - }, { + local message = History.Message:new("assistant", cancelled_text, { just_for_display = true, }) opts.on_messages_add({ message }) @@ -922,19 +913,21 @@ function M._stream(opts) Utils.debug("user reminder count", user_reminder_count) local message if #unfinished_todos > 0 then - message = History.Message:new({ - role = "user", - content = "You should use tool calls to answer the question, for example, use update_todo_status if the task step is done or cancelled.", - }, { - visible = false, - }) + message = History.Message:new( + "user", + "You should use tool calls to answer the question, for example, use update_todo_status if the task step is done or cancelled.", + { + visible = false, + } + ) else - message = History.Message:new({ - role = "user", - content = "You should use tool calls to answer the question, for example, use attempt_completion if the job is done.", - }, { - visible = false, - }) + message = History.Message:new( + "user", + "You should use tool calls to answer the question, for example, use attempt_completion if the job is done.", + { + visible = false, + } + ) end opts.on_messages_add({ message }) local new_opts = vim.tbl_deep_extend("force", opts, { @@ -958,12 +951,13 @@ function M._stream(opts) end if stop_opts.reason == "rate_limit" then local message = opts.on_messages_add - and History.Message:new({ - role = "assistant", - content = "", -- Actual content will be set below - }, { - just_for_display = true, - }) + and History.Message:new( + "assistant", + "", -- Actual content will be set below + { + just_for_display = true, + } + ) local timer = vim.loop.new_timer() if timer then diff --git a/lua/avante/llm_tools/dispatch_agent.lua b/lua/avante/llm_tools/dispatch_agent.lua index 06c7a98..cc248f9 100644 --- a/lua/avante/llm_tools/dispatch_agent.lua +++ b/lua/avante/llm_tools/dispatch_agent.lua @@ -263,10 +263,7 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub .. elapsed_time .. "s)" if session_ctx.on_messages_add then - local message = History.Message:new({ - role = "assistant", - content = "\n\n" .. summary, - }, { + local message = History.Message:new("assistant", "\n\n" .. summary, { just_for_display = true, }) session_ctx.on_messages_add({ message }) diff --git a/lua/avante/providers/claude.lua b/lua/avante/providers/claude.lua index af686f8..26fb8ca 100644 --- a/lua/avante/providers/claude.lua +++ b/lua/avante/providers/claude.lua @@ -161,6 +161,24 @@ function M:parse_response(ctx, data_stream, event_state, opts) end end if ctx.content_blocks == nil then ctx.content_blocks = {} end + + ---@param content AvanteLLMMessageContentItem + ---@param uuid? string + ---@return avante.HistoryMessage + local function new_assistant_message(content, uuid) + assert( + event_state == "content_block_start" + or event_state == "content_block_delta" + or event_state == "content_block_stop", + "called with unexpected event_state: " .. event_state + ) + return HistoryMessage:new("assistant", content, { + state = event_state == "content_block_stop" and "generated" or "generating", + turn_id = ctx.turn_id, + uuid = uuid, + }) + end + if event_state == "message_start" then local ok, jsn = pcall(vim.json.decode, data_stream) if not ok then return end @@ -172,55 +190,33 @@ function M:parse_response(ctx, data_stream, event_state, opts) content_block.stoppped = false ctx.content_blocks[jsn.index + 1] = content_block if content_block.type == "text" then - local msg = HistoryMessage:new({ - role = "assistant", - content = content_block.text, - }, { - state = "generating", - turn_id = ctx.turn_id, - }) + local msg = new_assistant_message(content_block.text) content_block.uuid = msg.uuid if opts.on_messages_add then opts.on_messages_add({ msg }) end - end - if content_block.type == "thinking" then + elseif content_block.type == "thinking" then if opts.on_chunk then opts.on_chunk("\n") end if opts.on_messages_add then - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "thinking", - thinking = content_block.thinking, - signature = content_block.signature, - }, - }, - }, { - state = "generating", - turn_id = ctx.turn_id, + local msg = new_assistant_message({ + type = "thinking", + thinking = content_block.thinking, + signature = content_block.signature, }) content_block.uuid = msg.uuid opts.on_messages_add({ msg }) end - end - if content_block.type == "tool_use" and opts.on_messages_add then - local incomplete_json = JsonParser.parse(content_block.input_json) - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - name = content_block.name, - id = content_block.id, - input = incomplete_json or {}, - }, - }, - }, { - state = "generating", - turn_id = ctx.turn_id, - }) - content_block.uuid = msg.uuid - opts.on_messages_add({ msg }) - -- opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) + elseif content_block.type == "tool_use" then + if opts.on_messages_add then + local incomplete_json = JsonParser.parse(content_block.input_json) + local msg = new_assistant_message({ + type = "tool_use", + name = content_block.name, + id = content_block.id, + input = incomplete_json or {}, + }) + content_block.uuid = msg.uuid + opts.on_messages_add({ msg }) + -- opts.on_stop({ reason = "tool_use", streaming_tool_use = true }) + end end elseif event_state == "content_block_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) @@ -233,33 +229,21 @@ function M:parse_response(ctx, data_stream, event_state, opts) elseif jsn.delta.type == "thinking_delta" then content_block.thinking = content_block.thinking .. jsn.delta.thinking if opts.on_chunk then opts.on_chunk(jsn.delta.thinking) end - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "thinking", - thinking = content_block.thinking, - signature = content_block.signature, - }, - }, - }, { - state = "generating", - uuid = content_block.uuid, - turn_id = ctx.turn_id, - }) - if opts.on_messages_add then opts.on_messages_add({ msg }) end + if opts.on_messages_add then + local msg = new_assistant_message({ + type = "thinking", + thinking = content_block.thinking, + signature = content_block.signature, + }, content_block.uuid) + opts.on_messages_add({ msg }) + end elseif jsn.delta.type == "text_delta" then content_block.text = content_block.text .. jsn.delta.text if opts.on_chunk then opts.on_chunk(jsn.delta.text) end - local msg = HistoryMessage:new({ - role = "assistant", - content = content_block.text, - }, { - state = "generating", - uuid = content_block.uuid, - turn_id = ctx.turn_id, - }) - if opts.on_messages_add then opts.on_messages_add({ msg }) end + if opts.on_messages_add then + local msg = new_assistant_message(content_block.text, content_block.uuid) + opts.on_messages_add({ msg }) + end elseif jsn.delta.type == "signature_delta" then if ctx.content_blocks[jsn.index + 1].signature == nil then ctx.content_blocks[jsn.index + 1].signature = "" end ctx.content_blocks[jsn.index + 1].signature = ctx.content_blocks[jsn.index + 1].signature .. jsn.delta.signature @@ -270,36 +254,11 @@ function M:parse_response(ctx, data_stream, event_state, opts) local content_block = ctx.content_blocks[jsn.index + 1] content_block.stoppped = true if content_block.type == "text" then - local msg = HistoryMessage:new({ - role = "assistant", - content = content_block.text, - }, { - state = "generated", - uuid = content_block.uuid, - turn_id = ctx.turn_id, - }) - if opts.on_messages_add then opts.on_messages_add({ msg }) end - end - if content_block.type == "tool_use" then - local complete_json = vim.json.decode(content_block.input_json) - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - name = content_block.name, - id = content_block.id, - input = complete_json or {}, - }, - }, - }, { - state = "generated", - uuid = content_block.uuid, - turn_id = ctx.turn_id, - }) - if opts.on_messages_add then opts.on_messages_add({ msg }) end - end - if content_block.type == "thinking" then + if opts.on_messages_add then + local msg = new_assistant_message(content_block.text, content_block.uuid) + opts.on_messages_add({ msg }) + end + elseif content_block.type == "thinking" then if opts.on_chunk then if content_block.thinking and content_block.thinking ~= vim.NIL and content_block.thinking:sub(-1) ~= "\n" then opts.on_chunk("\n\n\n") @@ -307,21 +266,25 @@ function M:parse_response(ctx, data_stream, event_state, opts) opts.on_chunk("\n\n") end end - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "thinking", - thinking = content_block.thinking, - signature = content_block.signature, - }, - }, - }, { - state = "generated", - uuid = content_block.uuid, - turn_id = ctx.turn_id, - }) - if opts.on_messages_add then opts.on_messages_add({ msg }) end + if opts.on_messages_add then + local msg = new_assistant_message({ + type = "thinking", + thinking = content_block.thinking, + signature = content_block.signature, + }, content_block.uuid) + opts.on_messages_add({ msg }) + end + elseif content_block.type == "tool_use" then + if opts.on_messages_add then + local complete_json = vim.json.decode(content_block.input_json) + local msg = new_assistant_message({ + type = "tool_use", + name = content_block.name, + id = content_block.id, + input = complete_json or {}, + }, content_block.uuid) + opts.on_messages_add({ msg }) + end end elseif event_state == "message_delta" then local ok, jsn = pcall(vim.json.decode, data_stream) diff --git a/lua/avante/providers/ollama.lua b/lua/avante/providers/ollama.lua index fe49589..650bc4b 100644 --- a/lua/avante/providers/ollama.lua +++ b/lua/avante/providers/ollama.lua @@ -135,27 +135,24 @@ function M:is_disable_stream() return false end ---@param tool_calls avante.OllamaToolCall[] ---@param opts AvanteLLMStreamOptions function M:add_tool_use_messages(tool_calls, opts) - local msgs = {} - for _, tool_call in ipairs(tool_calls) do - local id = Utils.uuid() - local func = tool_call["function"] - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - name = func.name, - id = id, - input = func.arguments, - }, - }, - }, { - state = "generated", - uuid = id, - }) - table.insert(msgs, msg) + if opts.on_messages_add then + local msgs = {} + for _, tool_call in ipairs(tool_calls) do + local id = Utils.uuid() + local func = tool_call["function"] + local msg = HistoryMessage:new("assistant", { + type = "tool_use", + name = func.name, + id = id, + input = func.arguments, + }, { + state = "generated", + uuid = id, + }) + table.insert(msgs, msg) + end + opts.on_messages_add(msgs) end - if opts.on_messages_add then opts.on_messages_add(msgs) end end function M:parse_stream_data(ctx, data, opts) diff --git a/lua/avante/providers/openai.lua b/lua/avante/providers/openai.lua index 9f7fbe2..5b4d183 100644 --- a/lua/avante/providers/openai.lua +++ b/lua/avante/providers/openai.lua @@ -231,10 +231,7 @@ function M:add_text_message(ctx, text, state, opts) local content = ctx.content:gsub("", ""):gsub("", ""):gsub("", ""):gsub("", "") ctx.content = content - local msg = HistoryMessage:new({ - role = "assistant", - content = ctx.content, - }, { + local msg = HistoryMessage:new("assistant", ctx.content, { state = state, uuid = ctx.content_uuid, original_content = ctx.content, @@ -299,16 +296,11 @@ function M:add_text_message(ctx, text, state, opts) has_tool_use = true local msg_uuid = ctx.content_uuid .. "-" .. idx local tool_use_id = msg_uuid - local msg_ = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - name = item.tool_name, - id = tool_use_id, - input = input, - }, - }, + local msg_ = HistoryMessage:new("assistant", { + type = "tool_use", + name = item.tool_name, + id = tool_use_id, + input = input, }, { state = state, uuid = msg_uuid, @@ -342,15 +334,10 @@ end function M:add_thinking_message(ctx, text, state, opts) if ctx.reasonging_content == nil then ctx.reasonging_content = "" end ctx.reasonging_content = ctx.reasonging_content .. text - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "thinking", - thinking = ctx.reasonging_content, - signature = "", - }, - }, + local msg = HistoryMessage:new("assistant", { + type = "thinking", + thinking = ctx.reasonging_content, + signature = "", }, { state = state, uuid = ctx.reasonging_content_uuid, @@ -362,16 +349,11 @@ end function M:add_tool_use_message(ctx, tool_use, state, opts) local jsn = JsonParser.parse(tool_use.input_json) - local msg = HistoryMessage:new({ - role = "assistant", - content = { - { - type = "tool_use", - name = tool_use.name, - id = tool_use.id, - input = jsn or {}, - }, - }, + local msg = HistoryMessage:new("assistant", { + type = "tool_use", + name = tool_use.name, + id = tool_use.id, + input = jsn or {}, }, { state = state, uuid = tool_use.uuid, diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 51851af..f51c522 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -2094,6 +2094,7 @@ function Sidebar:add_history_messages(messages) end) end +-- FIXME: this is used by external plugin users ---@param messages AvanteLLMMessage | AvanteLLMMessage[] ---@param options {visible?: boolean} function Sidebar:add_chat_history(messages, options) @@ -2102,25 +2103,18 @@ function Sidebar:add_chat_history(messages, options) local is_first_user = true local history_messages = {} for _, message in ipairs(messages) do - local content = message.content - if message.role == "system" and type(content) == "string" then - ---@cast content string - self.chat_history.system_prompt = content - goto continue - end - local history_message = History.Message:new(message) - if message.role == "user" and is_first_user then - is_first_user = false - history_message.is_user_submission = true - history_message.provider = Config.provider - history_message.model = Config.get_provider_config(Config.provider).model - end - table.insert(history_messages, history_message) - ::continue:: - end - if options.visible ~= nil then - for _, history_message in ipairs(history_messages) do - history_message.visible = options.visible + local role = message.role + if role == "system" and type(message.content) == "string" then + self.chat_history.system_prompt = message.content --[[@as string]] + else + ---@type AvanteLLMMessageContentItem + local content = type(message.content) ~= "table" and message.content or message.content[1] + local msg_opts = { visible = options.visible } + if role == "user" and is_first_user then + msg_opts.is_user_submission = true + is_first_user = false + end + table.insert(history_messages, History.Message:new(role, content, msg_opts)) end end self:add_history_messages(history_messages) @@ -2415,7 +2409,7 @@ function Sidebar:create_input_container() if self.is_generating then self:add_history_messages({ - History.Message:new({ role = "user", content = request }), + History.Message:new("user", request), }) return end @@ -2553,10 +2547,7 @@ function Sidebar:create_input_container() local msg_content = stop_opts.error if type(msg_content) ~= "string" then msg_content = vim.inspect(msg_content) end self:add_history_messages({ - History.Message:new({ - role = "assistant", - content = "\n\nError: " .. msg_content, - }, { + History.Message:new("assistant", "\n\nError: " .. msg_content, { just_for_display = true, }), }) @@ -2584,10 +2575,7 @@ function Sidebar:create_input_container() if request and request ~= "" then self:add_history_messages({ - History.Message:new({ - role = "user", - content = request, - }, { + History.Message:new("user", request, { is_user_submission = true, selected_filepaths = selected_filepaths, selected_code = selected_code, @@ -2812,6 +2800,7 @@ function Sidebar:create_input_container() }) end +-- FIXME: this is used by external plugin users ---@param value string function Sidebar:set_input_value(value) if not self.containers.input then return end diff --git a/lua/avante/suggestion.lua b/lua/avante/suggestion.lua index 47b9241..6731553 100644 --- a/lua/avante/suggestion.lua +++ b/lua/avante/suggestion.lua @@ -231,7 +231,10 @@ L5: pass }, } - local history_messages = vim.iter(llm_messages):map(function(msg) return HistoryMessage:new(msg) end):totable() + local history_messages = vim + .iter(llm_messages) + :map(function(msg) return HistoryMessage:new(msg.role, msg.content) end) + :totable() local diagnostics = Utils.lsp.get_diagnostics(bufnr)