feat: memory 🧠 (#793)

This commit is contained in:
yetone
2024-11-04 16:20:28 +08:00
committed by GitHub
parent 579ef12f76
commit 1e8abbf798
13 changed files with 197 additions and 154 deletions

View File

@@ -13,7 +13,7 @@ local M = {}
M.api_key_name = "AZURE_OPENAI_API_KEY"
M.parse_message = O.parse_message
M.parse_messages = O.parse_messages
M.parse_response = O.parse_response
M.parse_curl_args = function(provider, code_opts)
@@ -34,7 +34,7 @@ M.parse_curl_args = function(provider, code_opts)
insecure = base.allow_insecure,
headers = headers,
body = vim.tbl_deep_extend("force", {
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = true,
}, body_opts),
}

View File

@@ -13,8 +13,8 @@ local P = require("avante.providers")
---@field type "image"
---@field source {type: "base64", media_type: string, data: string}
---
---@class AvanteClaudeMessage: AvanteBaseMessage
---@field role "user"
---@class AvanteClaudeMessage
---@field role "user" | "assistant"
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
---@class AvanteProviderFunctor
@@ -23,11 +23,44 @@ local M = {}
M.api_key_name = "ANTHROPIC_API_KEY"
M.use_xml_format = true
M.parse_message = function(opts)
---@type AvanteClaudeMessage[]
local message_content = {}
M.role_map = {
user = "user",
assistant = "assistant",
}
if Clipboard.support_paste_image() and opts.image_paths then
M.parse_messages = function(opts)
---@type AvanteClaudeMessage[]
local messages = {}
---@type {idx: integer, length: integer}[]
local messages_with_length = {}
for idx, message in ipairs(opts.messages) do
table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) })
end
table.sort(messages_with_length, function(a, b) return a.length > b.length end)
---@type table<integer, boolean>
local top_three = {}
for i = 1, math.min(3, #messages_with_length) do
top_three[messages_with_length[i].idx] = true
end
for idx, message in ipairs(opts.messages) do
table.insert(messages, {
role = M.role_map[message.role],
content = {
{
type = "text",
text = message.content,
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
},
},
})
end
if Clipboard.support_paste_image() and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image",
@@ -38,36 +71,10 @@ M.parse_message = function(opts)
},
})
end
messages[#messages].content = message_content
end
---@type {idx: integer, length: integer}[]
local user_prompts_with_length = {}
for idx, user_prompt in ipairs(opts.user_prompts) do
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
end
table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end)
---@type table<integer, boolean>
local top_three = {}
for i = 1, math.min(3, #user_prompts_with_length) do
top_three[user_prompts_with_length[i].idx] = true
end
for idx, prompt_data in ipairs(opts.user_prompts) do
table.insert(message_content, {
type = "text",
text = prompt_data,
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
})
end
return {
{
role = "user",
content = message_content,
},
}
return messages
end
M.parse_response = function(data_stream, event_state, opts)
@@ -96,7 +103,7 @@ M.parse_curl_args = function(provider, prompt_opts)
}
if not P.env.is_local("claude") then headers["x-api-key"] = provider.parse_api_key() end
local messages = M.parse_message(prompt_opts)
local messages = M.parse_messages(prompt_opts)
return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",

View File

@@ -42,17 +42,18 @@ local M = {}
M.api_key_name = "CO_API_KEY"
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
M.role_map = {
user = "user",
assistant = "assistant",
}
M.parse_message = function(opts)
---@type CohereMessage[]
local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
table.insert(acc, { type = "text", text = prompt })
return acc
end)
M.parse_messages = function(opts)
local messages = {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = user_content },
}
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
return { messages = messages }
end
@@ -91,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
body = vim.tbl_deep_extend("force", {
model = base.model,
stream = true,
}, M.parse_message(code_opts), body_opts),
}, M.parse_messages(code_opts), body_opts),
}
end

View File

@@ -118,12 +118,19 @@ M.state = nil
M.api_key_name = P.AVANTE_INTERNAL_KEY
M.tokenizer_id = "gpt-4o"
M.role_map = {
user = "user",
assistant = "assistant",
}
M.parse_message = function(opts)
return {
M.parse_messages = function(opts)
local messages = {
{ role = "system", content = opts.system_prompt },
{ role = "user", content = table.concat(opts.user_prompts, "\n") },
}
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
return messages
end
M.parse_response = O.parse_response
@@ -146,7 +153,7 @@ M.parse_curl_args = function(provider, code_opts)
},
body = vim.tbl_deep_extend("force", {
model = base.model,
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = true,
}, body_opts),
}

View File

@@ -6,10 +6,34 @@ local Clipboard = require("avante.clipboard")
local M = {}
M.api_key_name = "GEMINI_API_KEY"
M.role_map = {
user = "user",
assistant = "model",
}
-- M.tokenizer_id = "google/gemma-2b"
M.parse_message = function(opts)
local message_content = {}
M.parse_messages = function(opts)
local contents = {}
local prev_role = nil
vim.iter(opts.messages):each(function(message)
local role = message.role
if role == prev_role then
if role == "user" then
table.insert(contents, { role = "model", parts = {
{ text = "Ok, I understand." },
} })
else
table.insert(contents, { role = "user", parts = {
{ text = "Ok" },
} })
end
end
prev_role = role
table.insert(contents, { role = M.role_map[role] or role, parts = {
{ text = message.content },
} })
end)
if Clipboard.support_paste_image() and opts.image_paths then
for _, image_path in ipairs(opts.image_paths) do
@@ -20,13 +44,10 @@ M.parse_message = function(opts)
},
}
table.insert(message_content, image_data)
table.insert(contents[#contents].parts, image_data)
end
end
-- insert a part into parts
table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") })
return {
systemInstruction = {
role = "user",
@@ -36,12 +57,7 @@ M.parse_message = function(opts)
},
},
},
contents = {
{
role = "user",
parts = message_content,
},
},
contents = contents,
}
end
@@ -78,7 +94,7 @@ M.parse_curl_args = function(provider, code_opts)
proxy = base.proxy,
insecure = base.allow_insecure,
headers = { ["Content-Type"] = "application/json" },
body = vim.tbl_deep_extend("force", {}, M.parse_message(code_opts), body_opts),
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
}
end

View File

@@ -14,22 +14,22 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field on_chunk AvanteChunkParser
---@field on_complete AvanteCompleteParser
---
---@class AvanteLLMMessage
---@field role "user" | "assistant"
---@field content string
---
---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string
---@field user_prompts string[]
---@field messages AvanteLLMMessage[]
---@field image_paths? string[]
---
---@class AvanteBaseMessage
---@field role "user" | "system"
---@field content string
---
---@class AvanteGeminiMessage
---@field role "user"
---@field parts { text: string }[]
---
---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage
---
---@alias AvanteMessageParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
---
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>}
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
@@ -65,13 +65,14 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
---@field parse_api_key? fun(): string | nil
---
---@class AvanteProviderFunctor
---@field parse_message AvanteMessageParser
---@field role_map table<"user" | "assistant", string>
---@field parse_messages AvanteMessagesParser
---@field parse_response AvanteResponseParser
---@field parse_curl_args AvanteCurlArgsParser
---@field setup fun(): nil
---@field has fun(): boolean
---@field api_key_name string
---@field tokenizer_id [string] | "gpt-4o"
---@field tokenizer_id string | "gpt-4o"
---@field use_xml_format boolean
---@field model? string
---@field parse_api_key fun(): string | nil

View File

@@ -33,29 +33,15 @@ local M = {}
M.api_key_name = "OPENAI_API_KEY"
M.role_map = {
user = "user",
assistant = "assistant",
}
---@param opts AvantePromptOptions
M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
M.parse_message = function(opts)
---@type OpenAIMessage[]
local user_content = {}
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
for _, image_path in ipairs(opts.image_paths) do
table.insert(user_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
else
user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
table.insert(acc, { type = "text", text = prompt })
return acc
end)
end
M.get_user_message = function(opts) return table.concat(opts.messages, "\n") end
M.parse_messages = function(opts)
local messages = {}
local provider = P[Config.provider]
local base, _ = P.parse_config(provider)
@@ -68,8 +54,23 @@ M.parse_message = function(opts)
table.insert(messages, { role = "system", content = opts.system_prompt })
end
-- User message after the prompt
table.insert(messages, { role = "user", content = user_content })
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
messages[#messages].content = message_content
end
return messages
end
@@ -128,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
headers = headers,
body = vim.tbl_deep_extend("force", {
model = base.model,
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = stream,
}, body_opts),
}