fix: o1 model (#1992)
This commit is contained in:
@@ -612,13 +612,22 @@ function M.has_provider(provider_name) return vim.list_contains(M.provider_names
|
||||
---@param provider_name avante.ProviderName
|
||||
function M.get_provider_config(provider_name)
|
||||
if not M.has_provider(provider_name) then error("No provider found: " .. provider_name, 2) end
|
||||
if M._options[provider_name] ~= nil then
|
||||
return vim.deepcopy(M._options[provider_name], true)
|
||||
elseif M.vendors and M.vendors[provider_name] ~= nil then
|
||||
return vim.deepcopy(M.vendors[provider_name], true)
|
||||
else
|
||||
error("Failed to find provider: " .. provider_name, 2)
|
||||
local found = false
|
||||
local config = {}
|
||||
|
||||
if M.vendors and M.vendors[provider_name] ~= nil then
|
||||
found = true
|
||||
config = vim.tbl_deep_extend("force", config, vim.deepcopy(M.vendors[provider_name], true))
|
||||
end
|
||||
|
||||
if M._options[provider_name] ~= nil then
|
||||
found = true
|
||||
config = vim.tbl_deep_extend("force", config, vim.deepcopy(M._options[provider_name], true))
|
||||
end
|
||||
|
||||
if not found then error("Failed to find provider: " .. provider_name, 2) end
|
||||
|
||||
return config
|
||||
end
|
||||
|
||||
M.BASE_PROVIDER_KEYS = {
|
||||
|
||||
@@ -272,6 +272,7 @@ local function set_cursor(position, side)
|
||||
if not position then return end
|
||||
local target = side == SIDES.OURS and position.current or position.incoming
|
||||
api.nvim_win_set_cursor(0, { target.range_start + 1, 0 })
|
||||
vim.cmd("normal! zz")
|
||||
end
|
||||
|
||||
local show_keybinding_hint_extmark_id = nil
|
||||
|
||||
@@ -335,9 +335,10 @@ function M.generate_prompts(opts)
|
||||
remaining_tokens = remaining_tokens - Utils.tokens.calculate_tokens(message.content)
|
||||
end
|
||||
|
||||
local dropped_history_messages = {}
|
||||
if opts.prompt_opts and opts.prompt_opts.dropped_history_messages then
|
||||
dropped_history_messages = vim.list_extend(dropped_history_messages, opts.prompt_opts.dropped_history_messages)
|
||||
local pending_compaction_history_messages = {}
|
||||
if opts.prompt_opts and opts.prompt_opts.pending_compaction_history_messages then
|
||||
pending_compaction_history_messages =
|
||||
vim.list_extend(pending_compaction_history_messages, opts.prompt_opts.pending_compaction_history_messages)
|
||||
end
|
||||
|
||||
local cleaned_history_messages = history_messages
|
||||
@@ -350,6 +351,7 @@ function M.generate_prompts(opts)
|
||||
if Config.history.max_tokens > 0 then
|
||||
remaining_tokens = math.min(Config.history.max_tokens, remaining_tokens)
|
||||
end
|
||||
|
||||
-- Traverse the history in reverse, keeping only the latest history until the remaining tokens are exhausted and the first message role is "user"
|
||||
local retained_history_messages = {}
|
||||
for i = #cleaned_history_messages, 1, -1 do
|
||||
@@ -368,11 +370,11 @@ function M.generate_prompts(opts)
|
||||
vim.list_slice(cleaned_history_messages, #cleaned_history_messages - 1, #cleaned_history_messages)
|
||||
end
|
||||
|
||||
dropped_history_messages =
|
||||
pending_compaction_history_messages =
|
||||
vim.list_slice(cleaned_history_messages, 1, #cleaned_history_messages - #retained_history_messages)
|
||||
|
||||
dropped_history_messages = vim
|
||||
.iter(dropped_history_messages)
|
||||
pending_compaction_history_messages = vim
|
||||
.iter(pending_compaction_history_messages)
|
||||
:filter(function(msg) return msg.is_dummy ~= true end)
|
||||
:totable()
|
||||
|
||||
@@ -411,7 +413,7 @@ function M.generate_prompts(opts)
|
||||
messages = messages,
|
||||
image_paths = image_paths,
|
||||
tools = tools,
|
||||
dropped_history_messages = dropped_history_messages,
|
||||
pending_compaction_history_messages = pending_compaction_history_messages,
|
||||
}
|
||||
end
|
||||
|
||||
@@ -464,13 +466,18 @@ function M.curl(opts)
|
||||
end
|
||||
local data_match = line:match("^data:%s*(.+)$")
|
||||
if data_match then
|
||||
response_body = ""
|
||||
provider:parse_response(resp_ctx, data_match, current_event_state, handler_opts)
|
||||
else
|
||||
response_body = response_body .. line
|
||||
local ok, jsn = pcall(vim.json.decode, response_body)
|
||||
if ok then
|
||||
if jsn.error then
|
||||
handler_opts.on_stop({ reason = "error", error = jsn.error })
|
||||
else
|
||||
provider:parse_response(resp_ctx, response_body, current_event_state, handler_opts)
|
||||
end
|
||||
response_body = ""
|
||||
if jsn.error then handler_opts.on_stop({ reason = "error", error = jsn.error }) end
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -485,16 +492,13 @@ function M.curl(opts)
|
||||
|
||||
local temp_file = fn.tempname()
|
||||
local curl_body_file = temp_file .. "-request-body.json"
|
||||
local resp_body_file = temp_file .. "-response-body.json"
|
||||
local resp_body_file = temp_file .. "-response-body.txt"
|
||||
local headers_file = temp_file .. "-response-headers.txt"
|
||||
local json_content = vim.json.encode(spec.body)
|
||||
fn.writefile(vim.split(json_content, "\n"), curl_body_file)
|
||||
|
||||
Utils.debug("curl request body file:", curl_body_file)
|
||||
|
||||
Utils.debug("curl response body file:", resp_body_file)
|
||||
|
||||
local headers_file = temp_file .. "-headers.txt"
|
||||
|
||||
Utils.debug("curl headers file:", headers_file)
|
||||
|
||||
local function cleanup()
|
||||
@@ -609,7 +613,6 @@ function M.curl(opts)
|
||||
end
|
||||
end
|
||||
end
|
||||
Utils.debug("result", result)
|
||||
local retry_after = 10
|
||||
if headers_map["retry-after"] then retry_after = tonumber(headers_map["retry-after"]) or 10 end
|
||||
handler_opts.on_stop({ reason = "rate_limit", retry_after = retry_after })
|
||||
@@ -698,11 +701,11 @@ function M._stream(opts)
|
||||
local prompt_opts = M.generate_prompts(opts)
|
||||
|
||||
if
|
||||
prompt_opts.dropped_history_messages
|
||||
and #prompt_opts.dropped_history_messages > 0
|
||||
prompt_opts.pending_compaction_history_messages
|
||||
and #prompt_opts.pending_compaction_history_messages > 0
|
||||
and opts.on_memory_summarize
|
||||
then
|
||||
opts.on_memory_summarize(prompt_opts.dropped_history_messages)
|
||||
opts.on_memory_summarize(prompt_opts.pending_compaction_history_messages)
|
||||
return
|
||||
end
|
||||
|
||||
|
||||
@@ -166,12 +166,16 @@ When you're done, provide a clear and concise summary of what you found.]]):gsub
|
||||
end,
|
||||
}
|
||||
|
||||
local function on_memory_summarize(dropped_history_messages)
|
||||
Llm.summarize_memory(memory_content, dropped_history_messages or {}, function(memory)
|
||||
local function on_memory_summarize(pending_compaction_history_messages)
|
||||
Llm.summarize_memory(memory_content, pending_compaction_history_messages or {}, function(memory)
|
||||
if memory then stream_options.memory = memory.content end
|
||||
local new_history_messages = {}
|
||||
for _, msg in ipairs(history_messages) do
|
||||
if vim.iter(dropped_history_messages):find(function(dropped_msg) return dropped_msg.uuid == msg.uuid end) then
|
||||
if
|
||||
vim
|
||||
.iter(pending_compaction_history_messages)
|
||||
:find(function(pending_compaction_msg) return pending_compaction_msg.uuid == msg.uuid end)
|
||||
then
|
||||
goto continue
|
||||
end
|
||||
table.insert(new_history_messages, msg)
|
||||
|
||||
@@ -245,30 +245,37 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
opts.on_stop({ reason = "complete" })
|
||||
return
|
||||
end
|
||||
if not data_stream:match('"delta":') then return end
|
||||
---@type AvanteOpenAIChatResponse
|
||||
if data_stream == "[DONE]" then return end
|
||||
local jsn = vim.json.decode(data_stream)
|
||||
if not jsn.choices or not jsn.choices[1] then return end
|
||||
---@cast jsn AvanteOpenAIChatResponse
|
||||
if not jsn.choices then return end
|
||||
local choice = jsn.choices[1]
|
||||
if choice.delta.reasoning_content and choice.delta.reasoning_content ~= vim.NIL then
|
||||
if not choice then return end
|
||||
local delta = choice.delta
|
||||
if not delta then
|
||||
local provider_conf = Providers.parse_config(self)
|
||||
if provider_conf.model:match("o1") then delta = choice.message end
|
||||
end
|
||||
if not delta then return end
|
||||
if delta.reasoning_content and delta.reasoning_content ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning_content
|
||||
self:add_thinking_message(ctx, choice.delta.reasoning_content, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning_content) end
|
||||
elseif choice.delta.reasoning and choice.delta.reasoning ~= vim.NIL then
|
||||
ctx.last_think_content = delta.reasoning_content
|
||||
self:add_thinking_message(ctx, delta.reasoning_content, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(delta.reasoning_content) end
|
||||
elseif delta.reasoning and delta.reasoning ~= vim.NIL then
|
||||
if ctx.returned_think_start_tag == nil or not ctx.returned_think_start_tag then
|
||||
ctx.returned_think_start_tag = true
|
||||
if opts.on_chunk then opts.on_chunk("<think>\n") end
|
||||
end
|
||||
ctx.last_think_content = choice.delta.reasoning
|
||||
self:add_thinking_message(ctx, choice.delta.reasoning, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.reasoning) end
|
||||
elseif choice.delta.tool_calls and choice.delta.tool_calls ~= vim.NIL then
|
||||
ctx.last_think_content = delta.reasoning
|
||||
self:add_thinking_message(ctx, delta.reasoning, "generating", opts)
|
||||
if opts.on_chunk then opts.on_chunk(delta.reasoning) end
|
||||
elseif delta.tool_calls and delta.tool_calls ~= vim.NIL then
|
||||
local choice_index = choice.index or 0
|
||||
for idx, tool_call in ipairs(choice.delta.tool_calls) do
|
||||
for idx, tool_call in ipairs(delta.tool_calls) do
|
||||
--- In Gemini's so-called OpenAI Compatible API, tool_call.index is nil, which is quite absurd! Therefore, a compatibility fix is needed here.
|
||||
if tool_call.index == nil then tool_call.index = choice_index + idx - 1 end
|
||||
if not ctx.tool_use_list then ctx.tool_use_list = {} end
|
||||
@@ -290,7 +297,7 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
self:add_tool_use_message(tool_use, "generating", opts)
|
||||
end
|
||||
end
|
||||
elseif choice.delta.content then
|
||||
elseif delta.content then
|
||||
if
|
||||
ctx.returned_think_start_tag ~= nil and (ctx.returned_think_end_tag == nil or not ctx.returned_think_end_tag)
|
||||
then
|
||||
@@ -304,9 +311,9 @@ function M:parse_response(ctx, data_stream, _, opts)
|
||||
end
|
||||
self:add_thinking_message(ctx, "", "generated", opts)
|
||||
end
|
||||
if choice.delta.content ~= vim.NIL then
|
||||
if opts.on_chunk then opts.on_chunk(choice.delta.content) end
|
||||
self:add_text_message(ctx, choice.delta.content, "generating", opts)
|
||||
if delta.content ~= vim.NIL then
|
||||
if opts.on_chunk then opts.on_chunk(delta.content) end
|
||||
self:add_text_message(ctx, delta.content, "generating", opts)
|
||||
end
|
||||
end
|
||||
if choice.finish_reason == "stop" or choice.finish_reason == "eos_token" then
|
||||
|
||||
@@ -694,6 +694,11 @@ local function insert_conflict_contents(bufnr, snippets)
|
||||
for _, snippet in ipairs(snippets) do
|
||||
local start_line, end_line = unpack(snippet.range)
|
||||
|
||||
local first_line_content = lines[start_line]
|
||||
local old_first_line_indentation = ""
|
||||
|
||||
if first_line_content then old_first_line_indentation = Utils.get_indentation(first_line_content) end
|
||||
|
||||
local result = {}
|
||||
table.insert(result, "<<<<<<< HEAD")
|
||||
for i = start_line, end_line do
|
||||
@@ -703,6 +708,14 @@ local function insert_conflict_contents(bufnr, snippets)
|
||||
|
||||
local snippet_lines = vim.split(snippet.content, "\n")
|
||||
|
||||
if #snippet_lines > 0 then
|
||||
local new_first_line_indentation = Utils.get_indentation(snippet_lines[1])
|
||||
if #old_first_line_indentation > #new_first_line_indentation then
|
||||
local line_indentation = old_first_line_indentation:sub(#new_first_line_indentation + 1)
|
||||
snippet_lines = vim.iter(snippet_lines):map(function(line) return line_indentation .. line end):totable()
|
||||
end
|
||||
end
|
||||
|
||||
vim.list_extend(result, snippet_lines)
|
||||
|
||||
table.insert(result, ">>>>>>> Snippet")
|
||||
@@ -1892,10 +1905,19 @@ function Sidebar:add_history_messages(messages)
|
||||
end
|
||||
self.chat_history.messages = history_messages
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
if self.chat_history.title == "untitled" and #messages > 0 then
|
||||
if
|
||||
self.chat_history.title == "untitled"
|
||||
and #messages > 0
|
||||
and messages[1].just_for_display ~= true
|
||||
and messages[1].state == "generated"
|
||||
then
|
||||
self.chat_history.title = "generating..."
|
||||
Llm.summarize_chat_thread_title(messages[1].message.content, function(title)
|
||||
self:reload_chat_history()
|
||||
if title then self.chat_history.title = title end
|
||||
if title then
|
||||
self.chat_history.title = title
|
||||
else
|
||||
self.chat_history.title = "untitled"
|
||||
end
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
end)
|
||||
end
|
||||
@@ -2371,14 +2393,7 @@ function Sidebar:create_input_container()
|
||||
if Config.behaviour.auto_apply_diff_after_generation then self:apply(false) end
|
||||
end, 0)
|
||||
|
||||
if self.chat_history.title == "untitled" then
|
||||
Llm.summarize_chat_thread_title(request, function(title)
|
||||
if title then self.chat_history.title = title end
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
end)
|
||||
else
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
end
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
end
|
||||
|
||||
if request and request ~= "" then
|
||||
@@ -2407,18 +2422,22 @@ function Sidebar:create_input_container()
|
||||
session_ctx = {},
|
||||
})
|
||||
|
||||
---@param dropped_history_messages avante.HistoryMessage[]
|
||||
local function on_memory_summarize(dropped_history_messages)
|
||||
---@param pending_compaction_history_messages avante.HistoryMessage[]
|
||||
local function on_memory_summarize(pending_compaction_history_messages)
|
||||
local history_memory = self.chat_history.memory
|
||||
Llm.summarize_memory(history_memory and history_memory.content, dropped_history_messages, function(memory)
|
||||
if memory then
|
||||
self.chat_history.memory = memory
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
stream_options.memory = memory.content
|
||||
Llm.summarize_memory(
|
||||
history_memory and history_memory.content,
|
||||
pending_compaction_history_messages,
|
||||
function(memory)
|
||||
if memory then
|
||||
self.chat_history.memory = memory
|
||||
Path.history.save(self.code.bufnr, self.chat_history)
|
||||
stream_options.memory = memory.content
|
||||
end
|
||||
stream_options.history_messages = self:get_history_messages_for_api()
|
||||
Llm.stream(stream_options)
|
||||
end
|
||||
stream_options.history_messages = self:get_history_messages_for_api()
|
||||
Llm.stream(stream_options)
|
||||
end)
|
||||
)
|
||||
end
|
||||
|
||||
stream_options.on_memory_summarize = on_memory_summarize
|
||||
|
||||
@@ -114,7 +114,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field messages AvanteLLMMessage[]
|
||||
---@field image_paths? string[]
|
||||
---@field tools? AvanteLLMTool[]
|
||||
---@field dropped_history_messages? avante.HistoryMessage[]
|
||||
---@field pending_compaction_history_messages? AvanteLLMMessage[]
|
||||
---
|
||||
---@class AvanteGeminiMessage
|
||||
---@field role "user"
|
||||
@@ -356,7 +356,7 @@ vim.g.avante_login = vim.g.avante_login
|
||||
---@field tool_result? AvanteLLMToolResult
|
||||
---@field tool_use? AvanteLLMToolUse
|
||||
---
|
||||
---@alias AvanteLLMMemorySummarizeCallback fun(dropped_history_messages: avante.HistoryMessage[]): nil
|
||||
---@alias AvanteLLMMemorySummarizeCallback fun(pending_compaction_history_messages: avante.HistoryMessage[]): nil
|
||||
---
|
||||
---@alias AvanteLLMToolUseState "generating" | "generated" | "running" | "succeeded" | "failed"
|
||||
---@alias avante.GenerateState "generating" | "tool calling" | "failed" | "succeeded" | "cancelled" | "searching" | "thinking"
|
||||
|
||||
Reference in New Issue
Block a user