diff --git a/lua/avante/config.lua b/lua/avante/config.lua index fcc3cb7..54f70b5 100644 --- a/lua/avante/config.lua +++ b/lua/avante/config.lua @@ -684,13 +684,15 @@ M._defaults = { ---@diagnostic disable-next-line: missing-fields M._options = {} +local function get_config_dir_path() return Utils.join_paths(vim.fn.expand("~"), ".config", "avante.nvim") end +local function get_config_file_path() return Utils.join_paths(get_config_dir_path(), "config.json") end + --- Function to save the last used model ---@param model_name string function M.save_last_model(model_name, provider_name) - local config_dir = Utils.join_paths(vim.fn.expand("~"), ".config", "avante.nvim") - local storage_path = Utils.join_paths(config_dir, "config.json") + local config_dir = get_config_dir_path() + local storage_path = get_config_file_path() - -- 确保目录存在 if not Utils.path_exists(config_dir) then vim.fn.mkdir(config_dir, "p") end local file = io.open(storage_path, "w") @@ -700,29 +702,64 @@ function M.save_last_model(model_name, provider_name) end end ---- Function to load the last used model ----@return string|nil, string|nil -function M.load_last_model() - local config_dir = Utils.join_paths(vim.fn.expand("~"), ".config", "avante.nvim") - local storage_path = Utils.join_paths(config_dir, "config.json") +--- Retrieves names of the last used model and provider. May remove saved config if it is deemed invalid +---@param known_providers table +---@return string|nil Model name +---@return string|nil Provider name +function M.get_last_used_model(known_providers) + local storage_path = get_config_file_path() local file = io.open(storage_path, "r") if file then local content = file:read("*a") file:close() - if content and content ~= "" then - local success, data = pcall(vim.json.decode, content) - if success and data and data.last_model and data.last_provider then - return data.last_model, data.last_provider - elseif success and data and data.last_model then - return data.last_model, nil - else - Utils.warn("Invalid or corrupt JSON in last model file: " .. storage_path, { title = "Avante" }) - end - else - Utils.warn("Last model file is empty: " .. storage_path, { title = "Avante" }) + + if not content or content == "" then + Utils.warn("Last used model file is empty: " .. storage_path) + -- Remove to not have repeated warnings + os.remove(storage_path) end + + local success, data = pcall(vim.json.decode, content) + if not success or not data or not data.last_model or data.last_model == "" or data.last_provider == "" then + Utils.warn("Invalid or corrupt JSON in last used model file: " .. storage_path) + -- Rename instead of deleting so user can examine contents + os.rename(storage_path, storage_path .. ".bad") + return + end + + if data.last_provider and not known_providers[data.last_provider] then + Utils.warn( + "Provider " .. data.last_provider .. " is no longer a valid provider, falling back to default configuration" + ) + os.remove(storage_path) + return + end + + return data.last_model, data.last_provider + end +end + +---Applies given model and provider to the config +---@param config avante.Config +---@param model_name string +---@param provider_name? string +local function apply_model_selection(config, model_name, provider_name) + local provider_list = config.providers or {} + local current_provider_name = config.provider + + local target_provider_name = provider_name or current_provider_name + local target_provider = provider_list[target_provider_name] + + if not target_provider then return end + + local current_provider_data = provider_list[current_provider_name] + local current_model_name = current_provider_data and current_provider_data.model + + if target_provider_name ~= current_provider_name or model_name ~= current_model_name then + config.provider = target_provider_name + target_provider.model = model_name + Utils.info(string.format("Using previously selected model: %s/%s", target_provider_name, model_name)) end - return nil end ---@param opts table|nil -- Optional table parameter for configuration settings @@ -861,23 +898,8 @@ function M.setup(opts) } ) - local last_model, last_provider = M.load_last_model() - if last_model then - local original_provider = merged.provider - local original_model = merged.providers - and merged.providers[original_provider] - and merged.providers[original_provider].model - if last_provider then merged.provider = last_provider end - if merged.providers and merged.provider and merged.providers[merged.provider] then - merged.providers[merged.provider].model = last_model - if last_model ~= original_model or last_provider ~= original_provider then - Utils.info( - "Using last model: " .. merged.provider .. "/" .. merged.providers[merged.provider].model, - { title = "Avante" } - ) - end - end - end + local last_model, last_provider = M.get_last_used_model(merged.providers or {}) + if last_model then apply_model_selection(merged, last_model, last_provider) end M._options = merged