feat: memory 🧠 (#793)

This commit is contained in:
yetone
2024-11-04 16:20:28 +08:00
committed by GitHub
parent 579ef12f76
commit 1e8abbf798
13 changed files with 197 additions and 154 deletions

View File

@@ -33,29 +33,15 @@ local M = {}
M.api_key_name = "OPENAI_API_KEY"
M.role_map = {
user = "user",
assistant = "assistant",
}
---@param opts AvantePromptOptions
M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
M.parse_message = function(opts)
---@type OpenAIMessage[]
local user_content = {}
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
for _, image_path in ipairs(opts.image_paths) do
table.insert(user_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
else
user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
table.insert(acc, { type = "text", text = prompt })
return acc
end)
end
M.get_user_message = function(opts) return table.concat(opts.messages, "\n") end
M.parse_messages = function(opts)
local messages = {}
local provider = P[Config.provider]
local base, _ = P.parse_config(provider)
@@ -68,8 +54,23 @@ M.parse_message = function(opts)
table.insert(messages, { role = "system", content = opts.system_prompt })
end
-- User message after the prompt
table.insert(messages, { role = "user", content = user_content })
vim
.iter(opts.messages)
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
local message_content = messages[#messages].content
if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end
for _, image_path in ipairs(opts.image_paths) do
table.insert(message_content, {
type = "image_url",
image_url = {
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
},
})
end
messages[#messages].content = message_content
end
return messages
end
@@ -128,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
headers = headers,
body = vim.tbl_deep_extend("force", {
model = base.model,
messages = M.parse_message(code_opts),
messages = M.parse_messages(code_opts),
stream = stream,
}, body_opts),
}