feat: ask selected code block (#39)

This commit is contained in:
yetone
2024-08-17 22:29:05 +08:00
committed by GitHub
parent dea737bf05
commit 3dca5f4764
9 changed files with 399 additions and 91 deletions

View File

@@ -19,8 +19,9 @@ Your primary task is to suggest code modifications with precise line number rang
1. Carefully analyze the original code, paying close attention to its structure and line numbers. Line numbers start from 1 and include ALL lines, even empty ones.
2. When suggesting modifications:
a. Explain why the change is necessary or beneficial.
b. Provide the exact code snippet to be replaced using this format:
a. Use the language in the question to reply. If there are non-English parts in the question, use the language of those parts.
b. Explain why the change is necessary or beneficial.
c. Provide the exact code snippet to be replaced using this format:
Replace lines: {{start_line}}-{{end_line}}
```{{language}}
@@ -58,14 +59,12 @@ Replace lines: {{start_line}}-{{end_line}}
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
]]
local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key then
error("ANTHROPIC_API_KEY environment variable is not set")
end
local user_prompt = base_user_prompt
local tokens = Config.claude.max_tokens
local headers = {
["Content-Type"] = "application/json",
@@ -79,33 +78,56 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
text = string.format("<code>```%s\n%s```</code>", code_lang, code_content),
}
if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end
if selected_code_content then
code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", code_lang, code_content)
end
local message_content = {
code_prompt_obj,
}
if selected_code_content then
local selected_code_obj = {
type = "text",
text = string.format("<code>```%s\n%s```</code>", code_lang, selected_code_content),
}
if Tiktoken.count(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end
table.insert(message_content, selected_code_obj)
end
table.insert(message_content, {
type = "text",
text = string.format("<question>%s</question>", question),
})
local user_prompt = base_user_prompt
local user_prompt_obj = {
type = "text",
text = user_prompt,
}
if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end
if Tiktoken.count(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end
table.insert(message_content, user_prompt_obj)
local body = {
model = Config.claude.model,
system = system_prompt,
messages = {
{
role = "user",
content = {
code_prompt_obj,
{
type = "text",
text = string.format("<question>%s</question>", question),
},
user_prompt_obj,
},
content = message_content,
},
},
stream = true,
@@ -154,21 +176,39 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
})
end
local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("OPENAI_API_KEY")
if not api_key and Config.provider == "openai" then
error("OPENAI_API_KEY environment variable is not set")
end
local user_prompt = base_user_prompt
.. "\n\nQUESTION:\n"
.. question
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question
if selected_code_content then
user_prompt = base_user_prompt
.. "\n\nCODE CONTEXT:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. selected_code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question
end
local url, headers, body
if Config.provider == "azure" then
@@ -258,13 +298,14 @@ end
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
if Config.provider == "openai" or Config.provider == "azure" then
call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
elseif Config.provider == "claude" then
call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
end
end