feat(tokenizers): support parsing from public URL (#765)
This commit is contained in:
@@ -76,7 +76,7 @@ Respect and use existing conventions, libraries, etc that are already present in
|
||||
},
|
||||
---@type AvanteSupportedProvider
|
||||
cohere = {
|
||||
endpoint = "https://api.cohere.com/v1",
|
||||
endpoint = "https://api.cohere.com/v2",
|
||||
model = "command-r-plus-08-2024",
|
||||
timeout = 30000, -- Timeout in milliseconds
|
||||
temperature = 0,
|
||||
|
||||
@@ -2,57 +2,69 @@ local Utils = require("avante.utils")
|
||||
local P = require("avante.providers")
|
||||
|
||||
---@alias CohereFinishReason "COMPLETE" | "LENGTH" | "ERROR"
|
||||
---@alias CohereStreamType "message-start" | "content-start" | "content-delta" | "content-end" | "message-end"
|
||||
---
|
||||
---@class CohereChatStreamResponse
|
||||
---@field event_type "stream-start" | "text-generation" | "stream-end"
|
||||
---@field is_finished boolean
|
||||
---
|
||||
---@class CohereTextGenerationResponse: CohereChatStreamResponse
|
||||
---@class CohereChatContent
|
||||
---@field type? CohereStreamType
|
||||
---@field text string
|
||||
---
|
||||
---@class CohereStreamEndResponse: CohereChatStreamResponse
|
||||
---@field response CohereChatResponse
|
||||
---@field finish_reason CohereFinishReason
|
||||
---@class CohereChatMessage
|
||||
---@field content CohereChatContent
|
||||
---
|
||||
---@class CohereChatResponse
|
||||
---@field text string
|
||||
---@field generation_id string
|
||||
---@field chat_history CohereMessage[]
|
||||
---@field finish_reason CohereFinishReason
|
||||
---@field meta {api_version: {version: integer}, billed_units: {input_tokens: integer, output_tokens: integer}, tokens: {input_tokens: integer, output_tokens: integer}}
|
||||
---@class CohereChatStreamBase
|
||||
---@field type CohereStreamType
|
||||
---@field index integer
|
||||
---
|
||||
---@class CohereChatContentDelta: CohereChatStreamBase
|
||||
---@field type "content-delta" | "content-start" | "content-end"
|
||||
---@field delta? { message: CohereChatMessage }
|
||||
---
|
||||
---@class CohereChatMessageStart: CohereChatStreamBase
|
||||
---@field type "message-start"
|
||||
---@field delta { message: { role: "assistant" } }
|
||||
---
|
||||
---@class CohereChatMessageEnd: CohereChatStreamBase
|
||||
---@field type "message-end"
|
||||
---@field delta { finish_reason: CohereFinishReason, usage: CohereChatUsage }
|
||||
---
|
||||
---@class CohereChatUsage
|
||||
---@field billed_units { input_tokens: integer, output_tokens: integer }
|
||||
---@field tokens { input_tokens: integer, output_tokens: integer }
|
||||
---
|
||||
---@alias CohereChatResponse CohereChatContentDelta | CohereChatMessageStart | CohereChatMessageEnd
|
||||
---
|
||||
---@class CohereMessage
|
||||
---@field role? "USER" | "SYSTEM" | "CHATBOT"
|
||||
---@field message string
|
||||
---@field type "text"
|
||||
---@field text string
|
||||
---
|
||||
---@class AvanteProviderFunctor
|
||||
local M = {}
|
||||
|
||||
M.api_key_name = "CO_API_KEY"
|
||||
M.tokenizer_id = "CohereForAI/c4ai-command-r-plus-08-2024"
|
||||
M.tokenizer_id = "https://storage.googleapis.com/cohere-public/tokenizers/command-r-08-2024.json"
|
||||
|
||||
M.parse_message = function(opts)
|
||||
return {
|
||||
preamble = opts.system_prompt,
|
||||
message = table.concat(opts.user_prompts, "\n"),
|
||||
---@type CohereMessage[]
|
||||
local user_content = vim.iter(opts.user_prompts):fold({}, function(acc, prompt)
|
||||
table.insert(acc, { type = "text", text = prompt })
|
||||
return acc
|
||||
end)
|
||||
local messages = {
|
||||
{ role = "system", content = opts.system_prompt },
|
||||
{ role = "user", content = user_content },
|
||||
}
|
||||
return { messages = messages }
|
||||
end
|
||||
|
||||
M.parse_stream_data = function(data, opts)
|
||||
---@type CohereChatStreamResponse
|
||||
---@type CohereChatResponse
|
||||
local json = vim.json.decode(data)
|
||||
if json.is_finished then
|
||||
opts.on_complete(nil)
|
||||
return
|
||||
end
|
||||
if json.event_type ~= nil then
|
||||
---@cast json CohereStreamEndResponse
|
||||
if json.event_type == "stream-end" and json.finish_reason == "COMPLETE" then
|
||||
if json.type ~= nil then
|
||||
if json.type == "message-end" and json.delta.finish_reason == "COMPLETE" then
|
||||
opts.on_complete(nil)
|
||||
return
|
||||
end
|
||||
---@cast json CohereTextGenerationResponse
|
||||
if json.event_type == "text-generation" then opts.on_chunk(json.text) end
|
||||
if json.type == "content-delta" then opts.on_chunk(json.delta.message.content.text) end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -83,4 +95,10 @@ M.parse_curl_args = function(provider, code_opts)
|
||||
}
|
||||
end
|
||||
|
||||
M.setup = function()
|
||||
P.env.parse_envvar(M)
|
||||
require("avante.tokenizers").setup(M.tokenizer_id, false)
|
||||
vim.g.avante_login = true
|
||||
end
|
||||
|
||||
return M
|
||||
|
||||
@@ -8,7 +8,9 @@ local tokenizers = nil
|
||||
local M = {}
|
||||
|
||||
---@param model "gpt-4o" | string
|
||||
M.setup = function(model)
|
||||
---@param warning? boolean
|
||||
M.setup = function(model, warning)
|
||||
warning = warning or true
|
||||
vim.defer_fn(function()
|
||||
local ok, core = pcall(require, "avante_tokenizers")
|
||||
if not ok then return end
|
||||
@@ -19,14 +21,15 @@ M.setup = function(model)
|
||||
core.from_pretrained(model)
|
||||
end, 1000)
|
||||
|
||||
local HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
if HF_TOKEN == nil and model ~= "gpt-4o" then
|
||||
Utils.warn(
|
||||
"Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
|
||||
{ once = true }
|
||||
)
|
||||
if warning then
|
||||
local HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
if HF_TOKEN == nil and model ~= "gpt-4o" then
|
||||
Utils.warn(
|
||||
"Please set HF_TOKEN environment variable to use HuggingFace tokenizer if " .. model .. " is gated",
|
||||
{ once = true }
|
||||
)
|
||||
end
|
||||
end
|
||||
vim.env.HF_HUB_DISABLE_PROGRESS_BARS = 1
|
||||
end
|
||||
|
||||
M.available = function() return tokenizers ~= nil end
|
||||
|
||||
Reference in New Issue
Block a user