feat(tokenizers): support parsing from public URL (#765)

This commit is contained in:
Aaron Pham
2024-10-27 02:17:35 -04:00
committed by GitHub
parent a8e2b9a00c
commit bdbbdec88c
7 changed files with 236 additions and 89 deletions

View File

@@ -76,7 +76,7 @@ Respect and use existing conventions, libraries, etc that are already present in
},
---@type AvanteSupportedProvider
cohere = {
endpoint = "https://api.cohere.com/v2",
model = "command-r-plus-08-2024",
timeout = 30000, -- Timeout in milliseconds
temperature = 0,

View File

@@ -2,57 +2,69 @@ local Utils = require("avante.utils")
local P = require("avante.providers")
---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
---@alias CohereStreamType "message-start" | "content-start" | "content-delta" | "content-end" | "message-end"
---
---@class CohereChatContent
---@field type? CohereStreamType
---@field text string
---
---@class CohereChatMessage
---@field content CohereChatContent
---
---@class CohereChatStreamBase
---@field type CohereStreamType
---@field index integer
---
---@class CohereChatContentDelta: CohereChatStreamBase
---@field type "content-delta" | "content-start" | "content-end"
---@field delta? { message: CohereChatMessage }
---
---@class CohereChatMessageStart: CohereChatStreamBase
---@field type "message-start"
---@field delta { message: { role: "assistant" } }
---
---@class CohereChatMessageEnd: CohereChatStreamBase
---@field type "message-end"
---@field delta { finish_reason: CohereFinishReason, usage: CohereChatUsage }
---
---@class CohereChatUsage
---@field billed_units { input_tokens: integer, output_tokens: integer }
---@field tokens { input_tokens: integer, output_tokens: integer }
---
---@alias CohereChatResponse CohereChatContentDelta | CohereChatMessageStart | CohereChatMessageEnd
---
---@class CohereMessage
---@field type "text"
---@field text string
---
---@class AvanteProviderFunctor
local M = {}

M.api_key_name = "CO_API_KEY"
-- Tokenizer is fetched from Cohere's public GCS URL, so no HF_TOKEN /
-- gated-HuggingFace-repo access is required (see tokenizers.setup call below).
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
---Convert Avante prompt options into a Cohere v2 /chat request body:
---a `messages` array with one system message and one user message whose
---content is a list of `{ type = "text", text = ... }` parts.
M.parse_message = function(opts)
  ---@type CohereMessage[]
  local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
    table.insert(acc, { type = "text", text = prompt })
    return acc
  end)
  local messages = {
    { role = "system", content = opts.system_prompt },
    { role = "user", content = user_content },
  }
  return { messages = messages }
end
---Consume one JSON payload from Cohere's v2 chat event stream.
---Incremental assistant text is forwarded via opts.on_chunk; a terminal
---"message-end" event with a COMPLETE finish reason fires opts.on_complete.
---Events without a `type` field are ignored.
M.parse_stream_data = function(data, opts)
  ---@type CohereChatResponse
  local json = vim.json.decode(data)
  if json.type ~= nil then
    -- "message-end" + COMPLETE terminates the stream.
    if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then
      opts.on_complete(nil)
      return
    end
    -- "content-delta" events carry the next fragment of generated text.
    if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end
  end
end
@@ -83,4 +95,10 @@ M.parse_curl_args = function(provider, code_opts)
}
end
-- Provider bootstrap: resolve the API key, warm up the tokenizer, and mark login done.
M.setup = function()
-- Reads the env var named by M.api_key_name ("CO_API_KEY") into the provider.
P.env.parse_envvar(M)
-- `false` is intended to suppress the HF_TOKEN warning, since the tokenizer id
-- here is a public URL. NOTE(review): tokenizers.setup does
-- `warning = warning or true`, which coerces `false` back to `true` — the flag
-- is ineffective; tokenizers.setup should use `if warning == nil then warning = true end`.
require("avante.tokenizers").setup(M.tokenizer_id, false)
vim.g.avante_login = true
end
return M

View File

@@ -8,7 +8,9 @@ local tokenizers = nil
local M = {}
---@param model "gpt-4o" | string
M.setup = function(model)
---@param warning? boolean
M.setup = function(model, warning)
warning = warning or true
vim.defer_fn(function()
local ok, core = pcall(require, "avante_tokenizers")
if not ok then return end
@@ -19,14 +21,15 @@ M.setup = function(model)
core.from_pretrained(model)
end, 1000)
local HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN == nil and model ~= "gpt-4o" then
Utils.warn(
"Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
{ once = true }
)
if warning then
local HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN == nil and model ~= "gpt-4o" then
Utils.warn(
"Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
{ once = true }
)
end
end
vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
end
M.available = function() return tokenizers ~= nil end