diff --git a/src/api/controller/generate.js b/src/api/controller/generate.js index dadb6a7..21c88db 100644 --- a/src/api/controller/generate.js +++ b/src/api/controller/generate.js @@ -3,16 +3,27 @@ const ErrorLog = require('../../models/Error'); /** * Save a prompt to the database + * @param {Object} params - The prompt data + * @param {string} params.model - The model used + * @param {string} [params.prompt] - The prompt text (for non-chat requests) + * @param {Array} [params.messages] - Array of chat messages (for chat requests) + * @param {Object} [params.request_data] - Complete request data + * @param {Object} [params.token_data] - Token counts and cost calculations + * @returns {Promise<Object>} The created prompt record */ -async function savePrompt({ model, prompt, messages, request_data }) { - return Prompt.createPrompt({ model, prompt, messages, request_data }); +async function savePrompt({ model, prompt, messages, request_data, token_data }) { + return Prompt.createPrompt({ model, prompt, messages, request_data, token_data }); } /** * Save an error to the database + * @param {Object} params - The error data + * @param {string} params.error_message - The error message + * @param {Object} [params.details] - Additional error details + * @returns {Promise<Object>} The created error record */ async function saveError({ error_message, details }) { - return ErrorLog.createError({ error_message, details }); + return ErrorLog.create({ error_message, details }); } module.exports = { savePrompt, saveError }; diff --git a/src/api/network/generate.js b/src/api/network/generate.js index 6343043..608c910 100644 --- a/src/api/network/generate.js +++ b/src/api/network/generate.js @@ -3,6 +3,7 @@ const fs = require("fs"); const path = require("path"); const { savePrompt, saveError } = require("../controller/generate"); const { processRequestTokens } = require("../../utils/tokenCounter"); +const Prompt = require("../../models/Prompt"); // Constants for file handling const MAX_FILE_SIZE = 
30000; // ~30KB per file @@ -455,6 +456,11 @@ async function handleGenerate(req, res) { prompt: isChatRequest ? null : cleanedRequest.prompt, messages: isChatRequest ? cleanedRequest.messages : null, request_data: requestData, + token_data: processRequestTokens( + cleanedRequest, + '', // Response not yet available; token_data is recomputed and persisted via Prompt.updatePromptTokens after the response completes + parseFloat(process.env.YOUR_MODEL_COST) || 0 + ) }); promptId = prompt.id; @@ -512,7 +518,7 @@ async function handleGenerate(req, res) { responseContent, parseFloat(process.env.YOUR_MODEL_COST) || 0 ); - await updatePromptTokens(promptId, tokenData); + await Prompt.updatePromptTokens(promptId, tokenData); } res.end(); } @@ -527,7 +533,7 @@ async function handleGenerate(req, res) { responseContent, parseFloat(process.env.YOUR_MODEL_COST) || 0 ); - await updatePromptTokens(promptId, tokenData); + await Prompt.updatePromptTokens(promptId, tokenData); res.status(ollamaResponse.status).json(ollamaResponse.data); }