feat: memory 🧠 (#793)
This commit is contained in:
@@ -13,7 +13,7 @@ local M = {}
|
||||
|
||||
M.api_key_name = "AZURE_OPENAI_API_KEY"
|
||||
|
||||
M.parse_message = O.parse_message
|
||||
M.parse_messages = O.parse_messages
|
||||
M.parse_response = O.parse_response
|
||||
|
||||
M.parse_curl_args = function(provider, code_opts)
|
||||
@@ -34,7 +34,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
insecure = base.allow_insecure,
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
messages = M.parse_message(code_opts),
|
||||
messages = M.parse_messages(code_opts),
|
||||
stream = true,
|
||||
}, body_opts),
|
||||
}
|
||||
|
||||
@@ -13,8 +13,8 @@ local P = require("avante.providers")
|
||||
---@field type "image"
|
||||
---@field source {type: "base64", media_type: string, data: string}
|
||||
---
|
||||
---@class AvanteClaudeMessage: AvanteBaseMessage
|
||||
---@field role "user"
|
||||
---@class AvanteClaudeMessage
|
||||
---@field role "user" | "assistant"
|
||||
---@field content [AvanteClaudeTextMessage | AvanteClaudeImageMessage][]
|
||||
|
||||
---@class AvanteProviderFunctor
|
||||
@@ -23,11 +23,44 @@ local M = {}
|
||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||
M.use_xml_format = true
|
||||
|
||||
M.parse_message = function(opts)
|
||||
---@type AvanteClaudeMessage[]
|
||||
local message_content = {}
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
if Clipboard.support_paste_image() and opts.image_paths then
|
||||
M.parse_messages = function(opts)
|
||||
---@type AvanteClaudeMessage[]
|
||||
local messages = {}
|
||||
|
||||
---@type {idx: integer, length: integer}[]
|
||||
local messages_with_length = {}
|
||||
for idx, message in ipairs(opts.messages) do
|
||||
table.insert(messages_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(message.content) })
|
||||
end
|
||||
|
||||
table.sort(messages_with_length, function(a, b) return a.length > b.length end)
|
||||
|
||||
---@type table<integer, boolean>
|
||||
local top_three = {}
|
||||
for i = 1, math.min(3, #messages_with_length) do
|
||||
top_three[messages_with_length[i].idx] = true
|
||||
end
|
||||
|
||||
for idx, message in ipairs(opts.messages) do
|
||||
table.insert(messages, {
|
||||
role = M.role_map[message.role],
|
||||
content = {
|
||||
{
|
||||
type = "text",
|
||||
text = message.content,
|
||||
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
|
||||
},
|
||||
},
|
||||
})
|
||||
end
|
||||
|
||||
if Clipboard.support_paste_image() and opts.image_paths and #opts.image_paths > 0 then
|
||||
local message_content = messages[#messages].content
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
table.insert(message_content, {
|
||||
type = "image",
|
||||
@@ -38,36 +71,10 @@ M.parse_message = function(opts)
|
||||
},
|
||||
})
|
||||
end
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
---@type {idx: integer, length: integer}[]
|
||||
local user_prompts_with_length = {}
|
||||
for idx, user_prompt in ipairs(opts.user_prompts) do
|
||||
table.insert(user_prompts_with_length, { idx = idx, length = Utils.tokens.calculate_tokens(user_prompt) })
|
||||
end
|
||||
|
||||
table.sort(user_prompts_with_length, function(a, b) return a.length > b.length end)
|
||||
|
||||
---@type table<integer, boolean>
|
||||
local top_three = {}
|
||||
for i = 1, math.min(3, #user_prompts_with_length) do
|
||||
top_three[user_prompts_with_length[i].idx] = true
|
||||
end
|
||||
|
||||
for idx, prompt_data in ipairs(opts.user_prompts) do
|
||||
table.insert(message_content, {
|
||||
type = "text",
|
||||
text = prompt_data,
|
||||
cache_control = top_three[idx] and { type = "ephemeral" } or nil,
|
||||
})
|
||||
end
|
||||
|
||||
return {
|
||||
{
|
||||
role = "user",
|
||||
content = message_content,
|
||||
},
|
||||
}
|
||||
return messages
|
||||
end
|
||||
|
||||
M.parse_response = function(data_stream, event_state, opts)
|
||||
@@ -96,7 +103,7 @@ M.parse_curl_args = function(provider, prompt_opts)
|
||||
}
|
||||
if not P.env.is_local("claude") then headers["x-api-key"] = provider.parse_api_key() end
|
||||
|
||||
local messages = M.parse_message(prompt_opts)
|
||||
local messages = M.parse_messages(prompt_opts)
|
||||
|
||||
return {
|
||||
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",
|
||||
|
||||
@@ -42,17 +42,18 @@ local M = {}
|
||||
|
||||
M.api_key_name = "CO_API_KEY"
|
||||
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
M.parse_message = function(opts)
|
||||
---@type CohereMessage[]
|
||||
local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
|
||||
table.insert(acc, { type = "text", text = prompt })
|
||||
return acc
|
||||
end)
|
||||
M.parse_messages = function(opts)
|
||||
local messages = {
|
||||
{ role = "system", content = opts.system_prompt },
|
||||
{ role = "user", content = user_content },
|
||||
}
|
||||
vim
|
||||
.iter(opts.messages)
|
||||
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||
return { messages = messages }
|
||||
end
|
||||
|
||||
@@ -91,7 +92,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
stream = true,
|
||||
}, M.parse_message(code_opts), body_opts),
|
||||
}, M.parse_messages(code_opts), body_opts),
|
||||
}
|
||||
end
|
||||
|
||||
|
||||
@@ -118,12 +118,19 @@ M.state = nil
|
||||
|
||||
M.api_key_name = P.AVANTE_INTERNAL_KEY
|
||||
M.tokenizer_id = "gpt-4o"
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
M.parse_message = function(opts)
|
||||
return {
|
||||
M.parse_messages = function(opts)
|
||||
local messages = {
|
||||
{ role = "system", content = opts.system_prompt },
|
||||
{ role = "user", content = table.concat(opts.user_prompts, "\n") },
|
||||
}
|
||||
vim
|
||||
.iter(opts.messages)
|
||||
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||
return messages
|
||||
end
|
||||
|
||||
M.parse_response = O.parse_response
|
||||
@@ -146,7 +153,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
},
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
messages = M.parse_message(code_opts),
|
||||
messages = M.parse_messages(code_opts),
|
||||
stream = true,
|
||||
}, body_opts),
|
||||
}
|
||||
|
||||
@@ -6,10 +6,34 @@ local Clipboard = require("avante.clipboard")
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "GEMINI_API_KEY"
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "model",
|
||||
}
|
||||
-- M.tokenizer_id = "google/gemma-2b"
|
||||
|
||||
M.parse_message = function(opts)
|
||||
local message_content = {}
|
||||
M.parse_messages = function(opts)
|
||||
local contents = {}
|
||||
local prev_role = nil
|
||||
|
||||
vim.iter(opts.messages):each(function(message)
|
||||
local role = message.role
|
||||
if role == prev_role then
|
||||
if role == "user" then
|
||||
table.insert(contents, { role = "model", parts = {
|
||||
{ text = "Ok, I understand." },
|
||||
} })
|
||||
else
|
||||
table.insert(contents, { role = "user", parts = {
|
||||
{ text = "Ok" },
|
||||
} })
|
||||
end
|
||||
end
|
||||
prev_role = role
|
||||
table.insert(contents, { role = M.role_map[role] or role, parts = {
|
||||
{ text = message.content },
|
||||
} })
|
||||
end)
|
||||
|
||||
if Clipboard.support_paste_image() and opts.image_paths then
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
@@ -20,13 +44,10 @@ M.parse_message = function(opts)
|
||||
},
|
||||
}
|
||||
|
||||
table.insert(message_content, image_data)
|
||||
table.insert(contents[#contents].parts, image_data)
|
||||
end
|
||||
end
|
||||
|
||||
-- insert a part into parts
|
||||
table.insert(message_content, { text = table.concat(opts.user_prompts, "\n") })
|
||||
|
||||
return {
|
||||
systemInstruction = {
|
||||
role = "user",
|
||||
@@ -36,12 +57,7 @@ M.parse_message = function(opts)
|
||||
},
|
||||
},
|
||||
},
|
||||
contents = {
|
||||
{
|
||||
role = "user",
|
||||
parts = message_content,
|
||||
},
|
||||
},
|
||||
contents = contents,
|
||||
}
|
||||
end
|
||||
|
||||
@@ -78,7 +94,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
proxy = base.proxy,
|
||||
insecure = base.allow_insecure,
|
||||
headers = { ["Content-Type"] = "application/json" },
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_message(code_opts), body_opts),
|
||||
body = vim.tbl_deep_extend("force", {}, M.parse_messages(code_opts), body_opts),
|
||||
}
|
||||
end
|
||||
|
||||
|
||||
@@ -14,22 +14,22 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
||||
---@field on_chunk AvanteChunkParser
|
||||
---@field on_complete AvanteCompleteParser
|
||||
---
|
||||
---@class AvanteLLMMessage
|
||||
---@field role "user" | "assistant"
|
||||
---@field content string
|
||||
---
|
||||
---@class AvantePromptOptions: table<[string], string>
|
||||
---@field system_prompt string
|
||||
---@field user_prompts string[]
|
||||
---@field messages AvanteLLMMessage[]
|
||||
---@field image_paths? string[]
|
||||
---
|
||||
---@class AvanteBaseMessage
|
||||
---@field role "user" | "system"
|
||||
---@field content string
|
||||
---
|
||||
---@class AvanteGeminiMessage
|
||||
---@field role "user"
|
||||
---@field parts { text: string }[]
|
||||
---
|
||||
---@alias AvanteChatMessage AvanteClaudeMessage | OpenAIMessage | AvanteGeminiMessage
|
||||
---
|
||||
---@alias AvanteMessageParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
||||
---@alias AvanteMessagesParser fun(opts: AvantePromptOptions): AvanteChatMessage[]
|
||||
---
|
||||
---@class AvanteCurlOutput: {url: string, proxy: string, insecure: boolean, body: table<string, any> | string, headers: table<string, string>}
|
||||
---@alias AvanteCurlArgsParser fun(opts: AvanteProvider | AvanteProviderFunctor, code_opts: AvantePromptOptions): AvanteCurlOutput
|
||||
@@ -65,13 +65,14 @@ local DressingState = { winid = nil, input_winid = nil, input_bufnr = nil }
|
||||
---@field parse_api_key? fun(): string | nil
|
||||
---
|
||||
---@class AvanteProviderFunctor
|
||||
---@field parse_message AvanteMessageParser
|
||||
---@field role_map table<"user" | "assistant", string>
|
||||
---@field parse_messages AvanteMessagesParser
|
||||
---@field parse_response AvanteResponseParser
|
||||
---@field parse_curl_args AvanteCurlArgsParser
|
||||
---@field setup fun(): nil
|
||||
---@field has fun(): boolean
|
||||
---@field api_key_name string
|
||||
---@field tokenizer_id [string] | "gpt-4o"
|
||||
---@field tokenizer_id string | "gpt-4o"
|
||||
---@field use_xml_format boolean
|
||||
---@field model? string
|
||||
---@field parse_api_key fun(): string | nil
|
||||
|
||||
@@ -33,29 +33,15 @@ local M = {}
|
||||
|
||||
M.api_key_name = "OPENAI_API_KEY"
|
||||
|
||||
M.role_map = {
|
||||
user = "user",
|
||||
assistant = "assistant",
|
||||
}
|
||||
|
||||
---@param opts AvantePromptOptions
|
||||
M.get_user_message = function(opts) return table.concat(opts.user_prompts, "\n") end
|
||||
|
||||
M.parse_message = function(opts)
|
||||
---@type OpenAIMessage[]
|
||||
local user_content = {}
|
||||
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
table.insert(user_content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
||||
},
|
||||
})
|
||||
end
|
||||
vim.iter(opts.user_prompts):each(function(prompt) table.insert(user_content, { type = "text", text = prompt }) end)
|
||||
else
|
||||
user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
|
||||
table.insert(acc, { type = "text", text = prompt })
|
||||
return acc
|
||||
end)
|
||||
end
|
||||
M.get_user_message = function(opts) return table.concat(opts.messages, "\n") end
|
||||
|
||||
M.parse_messages = function(opts)
|
||||
local messages = {}
|
||||
local provider = P[Config.provider]
|
||||
local base, _ = P.parse_config(provider)
|
||||
@@ -68,8 +54,23 @@ M.parse_message = function(opts)
|
||||
table.insert(messages, { role = "system", content = opts.system_prompt })
|
||||
end
|
||||
|
||||
-- User message after the prompt
|
||||
table.insert(messages, { role = "user", content = user_content })
|
||||
vim
|
||||
.iter(opts.messages)
|
||||
:each(function(msg) table.insert(messages, { role = M.role_map[msg.role], content = msg.content }) end)
|
||||
|
||||
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
|
||||
local message_content = messages[#messages].content
|
||||
if type(message_content) ~= "table" then message_content = { type = "text", text = message_content } end
|
||||
for _, image_path in ipairs(opts.image_paths) do
|
||||
table.insert(message_content, {
|
||||
type = "image_url",
|
||||
image_url = {
|
||||
url = "data:image/png;base64," .. Clipboard.get_base64_content(image_path),
|
||||
},
|
||||
})
|
||||
end
|
||||
messages[#messages].content = message_content
|
||||
end
|
||||
|
||||
return messages
|
||||
end
|
||||
@@ -128,7 +129,7 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
headers = headers,
|
||||
body = vim.tbl_deep_extend("force", {
|
||||
model = base.model,
|
||||
messages = M.parse_message(code_opts),
|
||||
messages = M.parse_messages(code_opts),
|
||||
stream = stream,
|
||||
}, body_opts),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user