feat: Add claude tools to vertex claude provider (#1559)
* feat: Add claude tools to vertex claude provider * fix: export transform_tool from claude.lua * Include type for transfor_tool --------- Co-authored-by: Miguelo Sana <miguelo@incubeta.com>
This commit is contained in:
@@ -2,9 +2,20 @@ local Utils = require("avante.utils")
|
|||||||
local Clipboard = require("avante.clipboard")
|
local Clipboard = require("avante.clipboard")
|
||||||
local P = require("avante.providers")
|
local P = require("avante.providers")
|
||||||
|
|
||||||
|
---@class AvanteProviderFunctor
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
M.api_key_name = "ANTHROPIC_API_KEY"
|
||||||
|
M.support_prompt_caching = true
|
||||||
|
|
||||||
|
M.role_map = {
|
||||||
|
user = "user",
|
||||||
|
assistant = "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
---@param tool AvanteLLMTool
|
---@param tool AvanteLLMTool
|
||||||
---@return AvanteClaudeTool
|
---@return AvanteClaudeTool
|
||||||
local function transform_tool(tool)
|
function M.transform_tool(tool)
|
||||||
local input_schema_properties = {}
|
local input_schema_properties = {}
|
||||||
local required = {}
|
local required = {}
|
||||||
for _, field in ipairs(tool.param.fields) do
|
for _, field in ipairs(tool.param.fields) do
|
||||||
@@ -25,17 +36,6 @@ local function transform_tool(tool)
|
|||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
---@class AvanteProviderFunctor
|
|
||||||
local M = {}
|
|
||||||
|
|
||||||
M.api_key_name = "ANTHROPIC_API_KEY"
|
|
||||||
M.support_prompt_caching = true
|
|
||||||
|
|
||||||
M.role_map = {
|
|
||||||
user = "user",
|
|
||||||
assistant = "assistant",
|
|
||||||
}
|
|
||||||
|
|
||||||
function M:parse_messages(opts)
|
function M:parse_messages(opts)
|
||||||
---@type AvanteClaudeMessage[]
|
---@type AvanteClaudeMessage[]
|
||||||
local messages = {}
|
local messages = {}
|
||||||
@@ -278,7 +278,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
local tools = {}
|
local tools = {}
|
||||||
if not disable_tools and prompt_opts.tools then
|
if not disable_tools and prompt_opts.tools then
|
||||||
for _, tool in ipairs(prompt_opts.tools) do
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
table.insert(tools, transform_tool(tool))
|
table.insert(tools, self.transform_tool(tool))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ Vertex.api_key_name = "cmd:gcloud auth print-access-token"
|
|||||||
---@param prompt_opts AvantePromptOptions
|
---@param prompt_opts AvantePromptOptions
|
||||||
function M:parse_curl_args(prompt_opts)
|
function M:parse_curl_args(prompt_opts)
|
||||||
local provider_conf, request_body = P.parse_config(self)
|
local provider_conf, request_body = P.parse_config(self)
|
||||||
|
local disable_tools = provider_conf.disable_tools or false
|
||||||
local location = vim.fn.getenv("LOCATION")
|
local location = vim.fn.getenv("LOCATION")
|
||||||
local project_id = vim.fn.getenv("PROJECT_ID")
|
local project_id = vim.fn.getenv("PROJECT_ID")
|
||||||
local model_id = provider_conf.model or "default-model-id"
|
local model_id = provider_conf.model or "default-model-id"
|
||||||
@@ -30,6 +31,20 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
|
|
||||||
local system_prompt = prompt_opts.system_prompt or ""
|
local system_prompt = prompt_opts.system_prompt or ""
|
||||||
local messages = self:parse_messages(prompt_opts)
|
local messages = self:parse_messages(prompt_opts)
|
||||||
|
|
||||||
|
local tools = {}
|
||||||
|
if not disable_tools and prompt_opts.tools then
|
||||||
|
for _, tool in ipairs(prompt_opts.tools) do
|
||||||
|
table.insert(tools, P.claude.transform_tool(tool))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if self.support_prompt_caching and #tools > 0 then
|
||||||
|
local last_tool = vim.deepcopy(tools[#tools])
|
||||||
|
last_tool.cache_control = { type = "ephemeral" }
|
||||||
|
tools[#tools] = last_tool
|
||||||
|
end
|
||||||
|
|
||||||
request_body = vim.tbl_deep_extend("force", request_body, {
|
request_body = vim.tbl_deep_extend("force", request_body, {
|
||||||
anthropic_version = "vertex-2023-10-16",
|
anthropic_version = "vertex-2023-10-16",
|
||||||
temperature = 0,
|
temperature = 0,
|
||||||
@@ -43,6 +58,7 @@ function M:parse_curl_args(prompt_opts)
|
|||||||
cache_control = { type = "ephemeral" },
|
cache_control = { type = "ephemeral" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
tools = tools,
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -269,6 +269,7 @@ vim.g.avante_login = vim.g.avante_login
|
|||||||
---@field parse_api_key fun(): string | nil
|
---@field parse_api_key fun(): string | nil
|
||||||
---@field parse_stream_data? AvanteStreamParser
|
---@field parse_stream_data? AvanteStreamParser
|
||||||
---@field on_error? fun(result: table<string, any>): nil
|
---@field on_error? fun(result: table<string, any>): nil
|
||||||
|
---@field transform_tool? fun(tool: AvanteLLMTool): AvanteOpenAITool | AvanteClaudeTool
|
||||||
---
|
---
|
||||||
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
---@alias AvanteBedrockPayloadBuilder fun(self: AvanteBedrockModelHandler | AvanteBedrockProviderFunctor, prompt_opts: AvantePromptOptions, request_body: table<string, any>): table<string, any>
|
||||||
---
|
---
|
||||||
|
|||||||
Reference in New Issue
Block a user