refactor: history messages (#1934)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
local Providers = require("avante.providers")
|
||||
local Clipboard = require("avante.clipboard")
|
||||
local OpenAI = require("avante.providers").openai
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
@@ -10,14 +11,32 @@ M.role_map = {
|
||||
user = "user",
|
||||
assistant = "model",
|
||||
}
|
||||
-- M.tokenizer_id = "google/gemma-2b"
|
||||
|
||||
--- Whether streaming must be disabled for this provider.
--- Gemini supports streamed responses, so streaming stays enabled.
---@return boolean always false
function M:is_disable_stream()
  return false
end
|
||||
|
||||
--- Convert an Avante tool spec into a Gemini "function declaration" table.
--- When the tool has no input fields, `parameters` is omitted (nil), as the
--- Gemini API rejects an empty object schema.
---@param tool AvanteLLMTool
---@return table declaration # { name, description, parameters? }
function M:transform_to_function_declaration(tool)
  local props, required_fields = Utils.llm_tool_param_fields_to_json_schema(tool.param.fields)

  -- Only build a JSON schema when the tool actually declares parameters.
  local schema
  if not vim.tbl_isempty(props) then
    schema = {
      type = "object",
      properties = props,
      required = required_fields,
    }
  end

  -- A dynamic description getter, when present, takes precedence over the
  -- static one (falling back if it yields nil/false).
  local description = tool.get_description and tool.get_description() or tool.description

  return {
    name = tool.name,
    description = description,
    parameters = schema,
  }
end
|
||||
|
||||
function M:parse_messages(opts)
|
||||
local contents = {}
|
||||
local prev_role = nil
|
||||
|
||||
local tool_id_to_name = {}
|
||||
vim.iter(opts.messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role then
|
||||
@@ -54,9 +73,27 @@ function M:parse_messages(opts)
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "tool_use" then
|
||||
table.insert(parts, { text = item.name })
|
||||
tool_id_to_name[item.id] = item.name
|
||||
role = "model"
|
||||
table.insert(parts, {
|
||||
functionCall = {
|
||||
name = item.name,
|
||||
args = item.input,
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "tool_result" then
|
||||
table.insert(parts, { text = item.content })
|
||||
role = "function"
|
||||
local ok, content = pcall(vim.json.decode, item.content)
|
||||
if not ok then content = item.content end
|
||||
table.insert(parts, {
|
||||
functionResponse = {
|
||||
name = tool_id_to_name[item.tool_use_id],
|
||||
response = {
|
||||
name = tool_id_to_name[item.tool_use_id],
|
||||
content = content,
|
||||
},
|
||||
},
|
||||
})
|
||||
elseif type(item) == "table" and item.type == "thinking" then
|
||||
table.insert(parts, { text = item.thinking })
|
||||
elseif type(item) == "table" and item.type == "redacted_thinking" then
|
||||
@@ -96,22 +133,43 @@ end
|
||||
--- Parse one SSE data chunk from the Gemini streaming API.
--- Emits text deltas via `opts.on_chunk`, records text/tool-use messages into
--- the shared history helpers, and signals termination via `opts.on_stop`.
---@param ctx table per-request context (carries session_id / function_call_id)
---@param data_stream string raw JSON payload of one stream event
---@param opts table callbacks: on_chunk(text)?, on_stop({reason, ...})
function M:parse_response(ctx, data_stream, _, opts)
  local ok, json = pcall(vim.json.decode, data_stream)
  if not ok then
    -- BUGFIX: must return here; otherwise execution falls through and
    -- on_stop gets invoked a second time with reason = "complete".
    opts.on_stop({ reason = "error", error = json })
    return
  end
  if json.candidates and #json.candidates > 0 then
    local candidate = json.candidates[1]
    ---@type AvanteLLMToolUse[]
    local tool_use_list = {}
    for _, part in ipairs(candidate.content.parts) do
      if part.text then
        if opts.on_chunk then opts.on_chunk(part.text) end
        OpenAI:add_text_message(ctx, part.text, "generating", opts)
      elseif part.functionCall then
        -- Gemini does not assign tool-call ids; synthesize a stable one from
        -- the session id plus a per-request counter.
        if not ctx.function_call_id then ctx.function_call_id = 0 end
        ctx.function_call_id = ctx.function_call_id + 1
        local tool_use = {
          id = ctx.session_id .. "-" .. tostring(ctx.function_call_id),
          name = part.functionCall.name,
          input_json = vim.json.encode(part.functionCall.args),
        }
        table.insert(tool_use_list, tool_use)
        OpenAI:add_tool_use_message(tool_use, "generated", opts)
      end
    end
    if candidate.finishReason and candidate.finishReason == "STOP" then
      OpenAI:finish_pending_messages(ctx, opts)
      if #tool_use_list > 0 then
        opts.on_stop({ reason = "tool_use", tool_use_list = tool_use_list })
      else
        opts.on_stop({ reason = "complete" })
      end
    end
  else
    -- No candidates in this event: flush whatever is pending and finish.
    OpenAI:finish_pending_messages(ctx, opts)
    opts.on_stop({ reason = "complete" })
  end
end
|
||||
|
||||
function M:parse_curl_args(prompt_opts)
|
||||
local provider_conf, request_body = P.parse_config(self)
|
||||
local provider_conf, request_body = Providers.parse_config(self)
|
||||
local disable_tools = provider_conf.disable_tools or false
|
||||
|
||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||
generationConfig = {
|
||||
@@ -125,6 +183,21 @@ function M:parse_curl_args(prompt_opts)
|
||||
local api_key = self.parse_api_key()
|
||||
if api_key == nil then error("Cannot get the gemini api key!") end
|
||||
|
||||
local function_declarations = {}
|
||||
if not disable_tools and prompt_opts.tools then
|
||||
for _, tool in ipairs(prompt_opts.tools) do
|
||||
table.insert(function_declarations, self:transform_to_function_declaration(tool))
|
||||
end
|
||||
end
|
||||
|
||||
if #function_declarations > 0 then
|
||||
request_body.tools = {
|
||||
{
|
||||
functionDeclarations = function_declarations,
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
return {
|
||||
url = Utils.url_join(
|
||||
provider_conf.endpoint,
|
||||
|
||||
Reference in New Issue
Block a user