diff --git a/src/cloudcode/index.js b/src/cloudcode/index.js
index 806f402..b1d43f5 100644
--- a/src/cloudcode/index.js
+++ b/src/cloudcode/index.js
@@ -12,12 +12,12 @@
 // Re-export public API
 export { sendMessage } from './message-handler.js';
 export { sendMessageStream } from './streaming-handler.js';
-export { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier } from './model-api.js';
+export { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier, isValidModel } from './model-api.js';
 
 // Default export for backwards compatibility
 import { sendMessage } from './message-handler.js';
 import { sendMessageStream } from './streaming-handler.js';
-import { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier } from './model-api.js';
+import { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier, isValidModel } from './model-api.js';
 
 export default {
     sendMessage,
@@ -25,5 +25,6 @@ export default {
     listModels,
     fetchAvailableModels,
     getModelQuotas,
-    getSubscriptionTier
+    getSubscriptionTier,
+    isValidModel
 };
diff --git a/src/cloudcode/model-api.js b/src/cloudcode/model-api.js
index 0dc6cf4..b567235 100644
--- a/src/cloudcode/model-api.js
+++ b/src/cloudcode/model-api.js
@@ -9,10 +9,18 @@ import {
     ANTIGRAVITY_HEADERS,
     LOAD_CODE_ASSIST_ENDPOINTS,
     LOAD_CODE_ASSIST_HEADERS,
-    getModelFamily
+    getModelFamily,
+    MODEL_VALIDATION_CACHE_TTL_MS
 } from '../constants.js';
 import { logger } from '../utils/logger.js';
 
+// Model validation cache
+const modelCache = {
+    validModels: new Set(),
+    lastFetched: 0,
+    fetchPromise: null // Prevents concurrent fetches
+};
+
 /**
  * Check if a model is supported (Claude or Gemini)
  * @param {string} modelId - Model ID to check
@@ -46,6 +54,10 @@ export async function listModels(token) {
         description: modelData.displayName || modelId
     }));
 
+    // Warm the model validation cache
+    modelCache.validModels = new Set(modelList.map(m => m.id));
+    modelCache.lastFetched = Date.now();
+
     return {
         object: 'list',
         data: modelList
@@ -246,3 +258,75 @@ export async function getSubscriptionTier(token) {
     logger.warn('[CloudCode] Failed to detect subscription tier from all endpoints. Defaulting to free.');
     return { tier: 'free', projectId: null };
 }
+
+/**
+ * Populate the model validation cache.
+ * Returns immediately while the cache is non-empty and younger than
+ * MODEL_VALIDATION_CACHE_TTL_MS; concurrent callers share one in-flight fetch.
+ * Fetch errors are logged, never thrown, and leave the previous cache in place.
+ * @param {string} token - OAuth access token
+ * @param {string} [projectId] - Optional project ID
+ * @returns {Promise<void>}
+ */
+async function populateModelCache(token, projectId = null) {
+    const now = Date.now();
+
+    // Check if cache is fresh
+    if (modelCache.validModels.size > 0 && (now - modelCache.lastFetched) < MODEL_VALIDATION_CACHE_TTL_MS) {
+        return;
+    }
+
+    // If already fetching, wait for it
+    if (modelCache.fetchPromise) {
+        await modelCache.fetchPromise;
+        return;
+    }
+
+    // Start fetch
+    modelCache.fetchPromise = (async () => {
+        try {
+            const data = await fetchAvailableModels(token, projectId);
+            if (data && data.models) {
+                const validIds = Object.keys(data.models).filter(modelId => isSupportedModel(modelId));
+                modelCache.validModels = new Set(validIds);
+                modelCache.lastFetched = Date.now();
+                logger.debug(`[CloudCode] Model cache populated with ${validIds.length} models`);
+            }
+        } catch (error) {
+            logger.warn(`[CloudCode] Failed to populate model cache: ${error.message}`);
+            // Don't throw - validation should degrade gracefully
+        } finally {
+            modelCache.fetchPromise = null;
+        }
+    })();
+
+    await modelCache.fetchPromise;
+}
+
+/**
+ * Check if a model ID is valid (exists in the available models list)
+ * Uses a cached model list with TTL-based refresh
+ * Fails open: returns true when the cache is empty or validation itself errors
+ * @param {string} modelId - Model ID to validate
+ * @param {string} token - OAuth access token for cache population
+ * @param {string} [projectId] - Optional project ID
+ * @returns {Promise<boolean>} True if model is valid
+ */
+export async function isValidModel(modelId, token, projectId = null) {
+    try {
+        // Populate cache if needed
+        await populateModelCache(token, projectId);
+
+        // If cache is populated, validate against it
+        if (modelCache.validModels.size > 0) {
+            return modelCache.validModels.has(modelId);
+        }
+
+        // Cache empty (fetch failed) - fail open, let API validate
+        return true;
+    } catch (error) {
+        logger.debug(`[CloudCode] Model validation error: ${error.message}`);
+        // Fail open - let the API validate
+        return true;
+    }
+}
diff --git a/src/constants.js b/src/constants.js
index 082315b..3a93493 100644
--- a/src/constants.js
+++ b/src/constants.js
@@ -156,6 +156,9 @@ export const GEMINI_SKIP_SIGNATURE = 'skip_thought_signature_validator';
 // Cache TTL for Gemini thoughtSignatures (2 hours)
 export const GEMINI_SIGNATURE_CACHE_TTL_MS = 2 * 60 * 60 * 1000;
 
+// Cache TTL for model validation (5 minutes)
+export const MODEL_VALIDATION_CACHE_TTL_MS = 5 * 60 * 1000;
+
 /**
  * Get the model family from model name (dynamic detection, no hardcoded list).
  * @param {string} modelName - The model name from the request
@@ -295,6 +298,7 @@ export default {
     GEMINI_MAX_OUTPUT_TOKENS,
     GEMINI_SKIP_SIGNATURE,
     GEMINI_SIGNATURE_CACHE_TTL_MS,
+    MODEL_VALIDATION_CACHE_TTL_MS,
     getModelFamily,
     isThinkingModel,
     OAUTH_CONFIG,
diff --git a/src/server.js b/src/server.js
index 5a06cc2..7c9c1b6 100644
--- a/src/server.js
+++ b/src/server.js
@@ -8,7 +8,7 @@ import express from 'express';
 import cors from 'cors';
 import path from 'path';
 import { fileURLToPath } from 'url';
-import { sendMessage, sendMessageStream, listModels, getModelQuotas, getSubscriptionTier } from './cloudcode/index.js';
+import { sendMessage, sendMessageStream, listModels, getModelQuotas, getSubscriptionTier, isValidModel } from './cloudcode/index.js';
 import { mountWebUI } from './webui/index.js';
 import { config } from './config.js';
 
@@ -720,6 +720,26 @@ app.post('/v1/messages', async (req, res) => {
 
         const modelId = requestedModel;
 
+        // Validate model ID before processing. This pre-check is best-effort:
+        // isValidModel fails open, so the token lookup feeding it must as well.
+        const { account: validationAccount } = accountManager.selectAccount();
+        if (validationAccount) {
+            let valid = true;
+            try {
+                const token = await accountManager.getTokenForAccount(validationAccount);
+                const projectId = validationAccount.subscription?.projectId || null;
+                valid = await isValidModel(modelId, token, projectId);
+            } catch {
+                // Fail open: a token failure on this one account says nothing about
+                // the model, so let the normal send path handle the request.
+                // NOTE(review): log this failure if a logger is in scope in server.js.
+            }
+
+            if (!valid) {
+                throw new Error(`invalid_request_error: Invalid model: ${modelId}. Use /v1/models to see available models.`);
+            }
+        }
+
         // Optimistic Retry: If ALL accounts are rate-limited for this model, reset them to force a fresh check.
         // If we have some available accounts, we try them first.
         if (accountManager.isAllRateLimited(modelId)) {