fix: use the original prompts (#397)

This commit is contained in:
yetone
2024-08-30 22:21:50 +08:00
committed by GitHub
parent 5cf85d59bd
commit 104484f17c
9 changed files with 216 additions and 135 deletions

View File

@@ -24,19 +24,19 @@ M.parse_message = function(prompt_opts)
end
end
local user_prompt = prompt_opts.user_prompt
for _, user_prompt in ipairs(prompt_opts.user_prompts) do
local user_prompt_obj = {
type = "text",
text = user_prompt,
}
local user_prompt_obj = {
type = "text",
text = user_prompt,
}
if Utils.tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end
if Utils.tokens.calculate_tokens(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
table.insert(message_content, user_prompt_obj)
end
table.insert(message_content, user_prompt_obj)
return {
{
role = "user",
@@ -75,6 +75,8 @@ M.parse_curl_args = function(provider, prompt_opts)
headers["x-api-key"] = provider.parse_api_key()
end
local messages = M.parse_message(prompt_opts)
return {
url = Utils.trim(base.endpoint, { suffix = "/" }) .. "/v1/messages",
proxy = base.proxy,
@@ -89,7 +91,7 @@ M.parse_curl_args = function(provider, prompt_opts)
cache_control = { type = "ephemeral" },
},
},
messages = M.parse_message(prompt_opts),
messages = messages,
stream = true,
}, body_opts),
}

View File

@@ -31,9 +31,14 @@ local M = {}
M.api_key_name = "CO_API_KEY"
M.parse_message = function(opts)
local user_prompt = ""
for _, user_prompt_ in ipairs(opts.user_prompts) do
user_prompt = user_prompt .. "\n\n" .. user_prompt_
end
return {
preamble = opts.system_prompt,
message = opts.user_prompt,
message = user_prompt,
}
end

View File

@@ -24,9 +24,11 @@ M.parse_message = function(opts)
end
-- insert a part into parts
table.insert(message_content, {
text = opts.user_prompt,
})
for _, user_prompt in ipairs(opts.user_prompts) do
table.insert(message_content, {
text = user_prompt,
})
end
return {
systemInstruction = {

View File

@@ -10,7 +10,7 @@ local Dressing = require("avante.ui.dressing")
---
---@class AvantePromptOptions: table<[string], string>
---@field system_prompt string
---@field user_prompt string
---@field user_prompts string[]
---@field image_paths? string[]
---
---@class AvanteBaseMessage

View File

@@ -28,6 +28,11 @@ local M = {}
M.api_key_name = "OPENAI_API_KEY"
M.parse_message = function(opts)
local user_prompt = ""
for _, user_prompt_ in ipairs(opts.user_prompts) do
user_prompt = user_prompt .. "\n\n" .. user_prompt_
end
---@type string | OpenAIMessage[]
local user_content
if Config.behaviour.support_paste_from_clipboard and opts.image_paths and #opts.image_paths > 0 then
@@ -40,9 +45,9 @@ M.parse_message = function(opts)
},
})
end
table.insert(user_content, { type = "text", text = opts.user_prompt })
table.insert(user_content, { type = "text", text = user_prompt })
else
user_content = opts.user_prompt
user_content = user_prompt
end
return {