fix: o1 model (#1992)

This commit is contained in:
yetone
2025-05-06 19:32:55 +08:00
committed by GitHub
parent 0b78b58760
commit 60897ee9a6
7 changed files with 109 additions and 66 deletions

View File

@@ -612,13 +612,22 @@ function M.has_provider(provider_name) return vim.list_contains(M.provider_names
---@param provider_name avante.ProviderName
--- Resolve the effective configuration table for a named provider.
--- The result merges the vendor-supplied defaults (M.vendors) with the
--- user's own options (M._options); user options take precedence via
--- tbl_deep_extend("force", ...). Deep copies are taken so callers can
--- mutate the returned table without corrupting shared state.
---@param provider_name avante.ProviderName name of the provider to look up
---@return table config merged, deep-copied provider configuration
function M.get_provider_config(provider_name)
  -- Fail fast (at the caller's position, level 2) for unknown providers.
  if not M.has_provider(provider_name) then error("No provider found: " .. provider_name, 2) end
  local found = false
  local config = {}
  -- Vendor defaults are applied first so user options can override them below.
  if M.vendors and M.vendors[provider_name] ~= nil then
    found = true
    config = vim.tbl_deep_extend("force", config, vim.deepcopy(M.vendors[provider_name], true))
  end
  -- User-configured options win over vendor defaults on key collisions.
  if M._options[provider_name] ~= nil then
    found = true
    config = vim.tbl_deep_extend("force", config, vim.deepcopy(M._options[provider_name], true))
  end
  -- has_provider passed but neither table held an entry: report at the caller.
  if not found then error("Failed to find provider: " .. provider_name, 2) end
  return config
end
M.BASE_PROVIDER_KEYS = {

View File

@@ -272,6 +272,7 @@ local function set_cursor(position, side)
if not position then return end
local target = side == SIDES.OURS and position.current or position.incoming
api.nvim_win_set_cursor(0, { target.range_start + 1, 0 })
vim.cmd("normal! zz")
end
local show_keybinding_hint_extmark_id = nil

View File

@@ -335,9 +335,10 @@ function M.generate_prompts(opts)
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
end
local dropped_history_messages = {}
if opts.prompt_opts and opts.prompt_opts.dropped_history_messages then
dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages)
local pending_compaction_history_messages = {}
if opts.prompt_opts and opts.prompt_opts.pending_compaction_history_messages then
pending_compaction_history_messages =
vim.list_extend(pending_compaction_history_messages, opts.prompt_opts.pending_compaction_history_messages)
end
local cleaned_history_messages = history_messages
@@ -350,6 +351,7 @@ function M.generate_prompts(opts)
if Config.history.max_tokens > 0 then
remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens)
end
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
local retained_history_messages = {}
for i = #cleaned_history_messages, 1, -1 do
@@ -368,11 +370,11 @@ function M.generate_prompts(opts)
vim.list_slice(cleaned_history_messages, #cleaned_history_messages - 1, #cleaned_history_messages)
end
dropped_history_messages =
pending_compaction_history_messages =
vim.list_slice(cleaned_history_messages, 1, #cleaned_history_messages - #retained_history_messages)
dropped_history_messages = vim
.iter(dropped_history_messages)
pending_compaction_history_messages = vim
.iter(pending_compaction_history_messages)
:filter(function(msg) return msg.is_dummy ~= true end)
:totable()
@@ -411,7 +413,7 @@ function M.generate_prompts(opts)
messages = messages,
image_paths = image_paths,
tools = tools,
dropped_history_messages = dropped_history_messages,
pending_compaction_history_messages = pending_compaction_history_messages,
}
end
@@ -464,13 +466,18 @@ function M.curl(opts)
end
local data_match = line:match("^data:%s*(.+)$")
if data_match then
response_body = ""
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
else
response_body = response_body .. line
local ok, jsn = pcall(vim.json.decode, response_body)
if ok then
if jsn.error then
handler_opts.on_stop({ reason = "error", error = jsn.error })
else
provider:parse_response(resp_ctx, response_body, current_event_state, handler_opts)
end
response_body = ""
if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end
end
end
end
@@ -485,16 +492,13 @@ function M.curl(opts)
local temp_file = fn.tempname()
local curl_body_file = temp_file .. "-request-body.json"
local resp_body_file = temp_file .. "-response-body.json"
local resp_body_file = temp_file .. "-response-body.txt"
local headers_file = temp_file .. "-response-headers.txt"
local json_content = vim.json.encode(spec.body)
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
Utils.debug("curl request body file:", curl_body_file)
Utils.debug("curl response body file:", resp_body_file)
local headers_file = temp_file .. "-headers.txt"
Utils.debug("curl headers file:", headers_file)
local function cleanup()
@@ -609,7 +613,6 @@ function M.curl(opts)
end
end
end
Utils.debug("result", result)
local retry_after = 10
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
@@ -698,11 +701,11 @@ function M._stream(opts)
local prompt_opts = M.generate_prompts(opts)
if
prompt_opts.dropped_history_messages
and #prompt_opts.dropped_history_messages > 0
prompt_opts.pending_compaction_history_messages
and #prompt_opts.pending_compaction_history_messages > 0
and opts.on_memory_summarize
then
opts.on_memory_summarize(prompt_opts.dropped_history_messages)
opts.on_memory_summarize(prompt_opts.pending_compaction_history_messages)
return
end

View File

@@ -166,12 +166,16 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
end,
}
local function on_memory_summarize(dropped_history_messages)
Llm.summarize_memory(memory_content, dropped_history_messages or {}, function(memory)
local function on_memory_summarize(pending_compaction_history_messages)
Llm.summarize_memory(memory_content, pending_compaction_history_messages or {}, function(memory)
if memory then stream_options.memory = memory.content end
local new_history_messages = {}
for _, msg in ipairs(history_messages) do
if vim.iter(dropped_history_messages):find(function(dropped_msg) return dropped_msg.uuid == msg.uuid end) then
if
vim
.iter(pending_compaction_history_messages)
:find(function(pending_compaction_msg) return pending_compaction_msg.uuid == msg.uuid end)
then
goto continue
end
table.insert(new_history_messages, msg)

View File

@@ -245,30 +245,37 @@ function M:parse_response(ctx, data_stream, _, opts)
opts.on_stop({ reason = "complete" })
return
end
if not data_stream:match('"delta":') then return end
---@type AvanteOpenAIChatResponse
if data_stream == "[DONE]" then return end
local jsn = vim.json.decode(data_stream)
if not jsn.choices or not jsn.choices[1] then return end
---@cast jsn AvanteOpenAIChatResponse
if not jsn.choices then return end
local choice = jsn.choices[1]
if choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
if not choice then return end
local delta = choice.delta
if not delta then
local provider_conf = Providers.parse_config(self)
if provider_conf.model:match("o1") then delta = choice.message end
end
if not delta then return end
if delta.reasoning_content and delta.reasoning_content ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true
if opts.on_chunk then opts.on_chunk("<think>\n") end
end
ctx.last_think_content = choice.delta.reasoning_content
self:add_thinking_message(ctx, choice.delta.reasoning_content, "generating", opts)
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning_content) end
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
ctx.last_think_content = delta.reasoning_content
self:add_thinking_message(ctx, delta.reasoning_content, "generating", opts)
if opts.on_chunk then opts.on_chunk(delta.reasoning_content) end
elseif delta.reasoning and delta.reasoning ~= vim.NIL then
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
ctx.returned_think_start_tag = true
if opts.on_chunk then opts.on_chunk("<think>\n") end
end
ctx.last_think_content = choice.delta.reasoning
self:add_thinking_message(ctx, choice.delta.reasoning, "generating", opts)
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning) end
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
ctx.last_think_content = delta.reasoning
self:add_thinking_message(ctx, delta.reasoning, "generating", opts)
if opts.on_chunk then opts.on_chunk(delta.reasoning) end
elseif delta.tool_calls and delta.tool_calls ~= vim.NIL then
local choice_index = choice.index or 0
for idx, tool_call in ipairs(choice.delta.tool_calls) do
for idx, tool_call in ipairs(delta.tool_calls) do
--- In Gemini's so-called OpenAI Compatible API, tool_call.index is nil, which is quite absurd! Therefore, a compatibility fix is needed here.
if tool_call.index == nil then tool_call.index = choice_index + idx - 1 end
if not ctx.tool_use_list then ctx.tool_use_list = {} end
@@ -290,7 +297,7 @@ function M:parse_response(ctx, data_stream, _, opts)
self:add_tool_use_message(tool_use, "generating", opts)
end
end
elseif choice.delta.content then
elseif delta.content then
if
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
then
@@ -304,9 +311,9 @@ function M:parse_response(ctx, data_stream, _, opts)
end
self:add_thinking_message(ctx, "", "generated", opts)
end
if choice.delta.content ~= vim.NIL then
if opts.on_chunk then opts.on_chunk(choice.delta.content) end
self:add_text_message(ctx, choice.delta.content, "generating", opts)
if delta.content ~= vim.NIL then
if opts.on_chunk then opts.on_chunk(delta.content) end
self:add_text_message(ctx, delta.content, "generating", opts)
end
end
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then

View File

@@ -694,6 +694,11 @@ local function insert_conflict_contents(bufnr, snippets)
for _, snippet in ipairs(snippets) do
local start_line, end_line = unpack(snippet.range)
local first_line_content = lines[start_line]
local old_first_line_indentation = ""
if first_line_content then old_first_line_indentation = Utils.get_indentation(first_line_content) end
local result = {}
table.insert(result, "<<<<<<< HEAD")
for i = start_line, end_line do
@@ -703,6 +708,14 @@ local function insert_conflict_contents(bufnr, snippets)
local snippet_lines = vim.split(snippet.content, "\n")
if #snippet_lines > 0 then
local new_first_line_indentation = Utils.get_indentation(snippet_lines[1])
if #old_first_line_indentation > #new_first_line_indentation then
local line_indentation = old_first_line_indentation:sub(#new_first_line_indentation + 1)
snippet_lines = vim.iter(snippet_lines):map(function(line) return line_indentation .. line end):totable()
end
end
vim.list_extend(result, snippet_lines)
table.insert(result, ">>>>>>> Snippet")
@@ -1892,10 +1905,19 @@ function Sidebar:add_history_messages(messages)
end
self.chat_history.messages = history_messages
Path.history.save(self.code.bufnr, self.chat_history)
if self.chat_history.title == "untitled" and #messages > 0 then
if
self.chat_history.title == "untitled"
and #messages > 0
and messages[1].just_for_display ~= true
and messages[1].state == "generated"
then
self.chat_history.title = "generating..."
Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
self:reload_chat_history()
if title then self.chat_history.title = title end
if title then
self.chat_history.title = title
else
self.chat_history.title = "untitled"
end
Path.history.save(self.code.bufnr, self.chat_history)
end)
end
@@ -2371,14 +2393,7 @@ function Sidebar:create_input_container()
if Config.behaviour.auto_apply_diff_after_generation then self:apply(false) end
end, 0)
if self.chat_history.title == "untitled" then
Llm.summarize_chat_thread_title(request, function(title)
if title then self.chat_history.title = title end
Path.history.save(self.code.bufnr, self.chat_history)
end)
else
Path.history.save(self.code.bufnr, self.chat_history)
end
Path.history.save(self.code.bufnr, self.chat_history)
end
if request and request ~= "" then
@@ -2407,18 +2422,22 @@ function Sidebar:create_input_container()
session_ctx = {},
})
---@param dropped_history_messages avante.HistoryMessage[]
local function on_memory_summarize(dropped_history_messages)
---@param pending_compaction_history_messages avante.HistoryMessage[]
local function on_memory_summarize(pending_compaction_history_messages)
local history_memory = self.chat_history.memory
Llm.summarize_memory(history_memory and history_memory.content, dropped_history_messages, function(memory)
if memory then
self.chat_history.memory = memory
Path.history.save(self.code.bufnr, self.chat_history)
stream_options.memory = memory.content
Llm.summarize_memory(
history_memory and history_memory.content,
pending_compaction_history_messages,
function(memory)
if memory then
self.chat_history.memory = memory
Path.history.save(self.code.bufnr, self.chat_history)
stream_options.memory = memory.content
end
stream_options.history_messages = self:get_history_messages_for_api()
Llm.stream(stream_options)
end
stream_options.history_messages = self:get_history_messages_for_api()
Llm.stream(stream_options)
end)
)
end
stream_options.on_memory_summarize = on_memory_summarize

View File

@@ -114,7 +114,7 @@ vim.g.avante_login = vim.g.avante_login
---@field messages AvanteLLMMessage[]
---@field image_paths? string[]
---@field tools? AvanteLLMTool[]
---@field dropped_history_messages? avante.HistoryMessage[]
---@field pending_compaction_history_messages? AvanteLLMMessage[]
---
---@class AvanteGeminiMessage
---@field role "user"
@@ -356,7 +356,7 @@ vim.g.avante_login = vim.g.avante_login
---@field tool_result? AvanteLLMToolResult
---@field tool_use? AvanteLLMToolUse
---
---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: avante.HistoryMessage[]): nil
---@alias AvanteLLMMemorySummarizeCallback fun(pending_compaction_history_messages: avante.HistoryMessage[]): nil
---
---@alias AvanteLLMToolUseState "generating" | "generated" | "running" | "succeeded" | "failed"
---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking"