feat: validate model IDs before processing requests
Add model validation cache with 5-minute TTL to reject invalid model IDs upfront instead of sending them to the API. This provides better error messages and avoids unnecessary API calls. - Add MODEL_VALIDATION_CACHE_TTL_MS constant (5 min) - Add isValidModel() with lazy cache population - Warm cache when listModels() is called - Validate model ID in /v1/messages before processing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -12,12 +12,12 @@
|
|||||||
// Re-export public API
|
// Re-export public API
|
||||||
export { sendMessage } from './message-handler.js';
|
export { sendMessage } from './message-handler.js';
|
||||||
export { sendMessageStream } from './streaming-handler.js';
|
export { sendMessageStream } from './streaming-handler.js';
|
||||||
export { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier } from './model-api.js';
|
export { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier, isValidModel } from './model-api.js';
|
||||||
|
|
||||||
// Default export for backwards compatibility
|
// Default export for backwards compatibility
|
||||||
import { sendMessage } from './message-handler.js';
|
import { sendMessage } from './message-handler.js';
|
||||||
import { sendMessageStream } from './streaming-handler.js';
|
import { sendMessageStream } from './streaming-handler.js';
|
||||||
import { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier } from './model-api.js';
|
import { listModels, fetchAvailableModels, getModelQuotas, getSubscriptionTier, isValidModel } from './model-api.js';
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
sendMessage,
|
sendMessage,
|
||||||
@@ -25,5 +25,6 @@ export default {
|
|||||||
listModels,
|
listModels,
|
||||||
fetchAvailableModels,
|
fetchAvailableModels,
|
||||||
getModelQuotas,
|
getModelQuotas,
|
||||||
getSubscriptionTier
|
getSubscriptionTier,
|
||||||
|
isValidModel
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,10 +9,18 @@ import {
|
|||||||
ANTIGRAVITY_HEADERS,
|
ANTIGRAVITY_HEADERS,
|
||||||
LOAD_CODE_ASSIST_ENDPOINTS,
|
LOAD_CODE_ASSIST_ENDPOINTS,
|
||||||
LOAD_CODE_ASSIST_HEADERS,
|
LOAD_CODE_ASSIST_HEADERS,
|
||||||
getModelFamily
|
getModelFamily,
|
||||||
|
MODEL_VALIDATION_CACHE_TTL_MS
|
||||||
} from '../constants.js';
|
} from '../constants.js';
|
||||||
import { logger } from '../utils/logger.js';
|
import { logger } from '../utils/logger.js';
|
||||||
|
|
||||||
|
// Model validation cache
|
||||||
|
const modelCache = {
|
||||||
|
validModels: new Set(),
|
||||||
|
lastFetched: 0,
|
||||||
|
fetchPromise: null // Prevents concurrent fetches
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if a model is supported (Claude or Gemini)
|
* Check if a model is supported (Claude or Gemini)
|
||||||
* @param {string} modelId - Model ID to check
|
* @param {string} modelId - Model ID to check
|
||||||
@@ -46,6 +54,10 @@ export async function listModels(token) {
|
|||||||
description: modelData.displayName || modelId
|
description: modelData.displayName || modelId
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Warm the model validation cache
|
||||||
|
modelCache.validModels = new Set(modelList.map(m => m.id));
|
||||||
|
modelCache.lastFetched = Date.now();
|
||||||
|
|
||||||
return {
|
return {
|
||||||
object: 'list',
|
object: 'list',
|
||||||
data: modelList
|
data: modelList
|
||||||
@@ -246,3 +258,71 @@ export async function getSubscriptionTier(token) {
|
|||||||
logger.warn('[CloudCode] Failed to detect subscription tier from all endpoints. Defaulting to free.');
|
logger.warn('[CloudCode] Failed to detect subscription tier from all endpoints. Defaulting to free.');
|
||||||
return { tier: 'free', projectId: null };
|
return { tier: 'free', projectId: null };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Populate the model validation cache
|
||||||
|
* @param {string} token - OAuth access token
|
||||||
|
* @param {string} [projectId] - Optional project ID
|
||||||
|
* @returns {Promise<void>}
|
||||||
|
*/
|
||||||
|
async function populateModelCache(token, projectId = null) {
|
||||||
|
const now = Date.now();
|
||||||
|
|
||||||
|
// Check if cache is fresh
|
||||||
|
if (modelCache.validModels.size > 0 && (now - modelCache.lastFetched) < MODEL_VALIDATION_CACHE_TTL_MS) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If already fetching, wait for it
|
||||||
|
if (modelCache.fetchPromise) {
|
||||||
|
await modelCache.fetchPromise;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start fetch
|
||||||
|
modelCache.fetchPromise = (async () => {
|
||||||
|
try {
|
||||||
|
const data = await fetchAvailableModels(token, projectId);
|
||||||
|
if (data && data.models) {
|
||||||
|
const validIds = Object.keys(data.models).filter(modelId => isSupportedModel(modelId));
|
||||||
|
modelCache.validModels = new Set(validIds);
|
||||||
|
modelCache.lastFetched = Date.now();
|
||||||
|
logger.debug(`[CloudCode] Model cache populated with ${validIds.length} models`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn(`[CloudCode] Failed to populate model cache: ${error.message}`);
|
||||||
|
// Don't throw - validation should degrade gracefully
|
||||||
|
} finally {
|
||||||
|
modelCache.fetchPromise = null;
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
await modelCache.fetchPromise;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a model ID is valid (exists in the available models list)
|
||||||
|
* Uses a cached model list with TTL-based refresh
|
||||||
|
* @param {string} modelId - Model ID to validate
|
||||||
|
* @param {string} token - OAuth access token for cache population
|
||||||
|
* @param {string} [projectId] - Optional project ID
|
||||||
|
* @returns {Promise<boolean>} True if model is valid
|
||||||
|
*/
|
||||||
|
export async function isValidModel(modelId, token, projectId = null) {
|
||||||
|
try {
|
||||||
|
// Populate cache if needed
|
||||||
|
await populateModelCache(token, projectId);
|
||||||
|
|
||||||
|
// If cache is populated, validate against it
|
||||||
|
if (modelCache.validModels.size > 0) {
|
||||||
|
return modelCache.validModels.has(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache empty (fetch failed) - fail open, let API validate
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
logger.debug(`[CloudCode] Model validation error: ${error.message}`);
|
||||||
|
// Fail open - let the API validate
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -156,6 +156,9 @@ export const GEMINI_SKIP_SIGNATURE = 'skip_thought_signature_validator';
|
|||||||
// Cache TTL for Gemini thoughtSignatures (2 hours)
|
// Cache TTL for Gemini thoughtSignatures (2 hours)
|
||||||
export const GEMINI_SIGNATURE_CACHE_TTL_MS = 2 * 60 * 60 * 1000;
|
export const GEMINI_SIGNATURE_CACHE_TTL_MS = 2 * 60 * 60 * 1000;
|
||||||
|
|
||||||
|
// Cache TTL for model validation (5 minutes)
|
||||||
|
export const MODEL_VALIDATION_CACHE_TTL_MS = 5 * 60 * 1000;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the model family from model name (dynamic detection, no hardcoded list).
|
* Get the model family from model name (dynamic detection, no hardcoded list).
|
||||||
* @param {string} modelName - The model name from the request
|
* @param {string} modelName - The model name from the request
|
||||||
@@ -295,6 +298,7 @@ export default {
|
|||||||
GEMINI_MAX_OUTPUT_TOKENS,
|
GEMINI_MAX_OUTPUT_TOKENS,
|
||||||
GEMINI_SKIP_SIGNATURE,
|
GEMINI_SKIP_SIGNATURE,
|
||||||
GEMINI_SIGNATURE_CACHE_TTL_MS,
|
GEMINI_SIGNATURE_CACHE_TTL_MS,
|
||||||
|
MODEL_VALIDATION_CACHE_TTL_MS,
|
||||||
getModelFamily,
|
getModelFamily,
|
||||||
isThinkingModel,
|
isThinkingModel,
|
||||||
OAUTH_CONFIG,
|
OAUTH_CONFIG,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import express from 'express';
|
|||||||
import cors from 'cors';
|
import cors from 'cors';
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
import { fileURLToPath } from 'url';
|
import { fileURLToPath } from 'url';
|
||||||
import { sendMessage, sendMessageStream, listModels, getModelQuotas, getSubscriptionTier } from './cloudcode/index.js';
|
import { sendMessage, sendMessageStream, listModels, getModelQuotas, getSubscriptionTier, isValidModel } from './cloudcode/index.js';
|
||||||
import { mountWebUI } from './webui/index.js';
|
import { mountWebUI } from './webui/index.js';
|
||||||
import { config } from './config.js';
|
import { config } from './config.js';
|
||||||
|
|
||||||
@@ -720,6 +720,18 @@ app.post('/v1/messages', async (req, res) => {
|
|||||||
|
|
||||||
const modelId = requestedModel;
|
const modelId = requestedModel;
|
||||||
|
|
||||||
|
// Validate model ID before processing
|
||||||
|
const { account: validationAccount } = accountManager.selectAccount();
|
||||||
|
if (validationAccount) {
|
||||||
|
const token = await accountManager.getTokenForAccount(validationAccount);
|
||||||
|
const projectId = validationAccount.subscription?.projectId || null;
|
||||||
|
const valid = await isValidModel(modelId, token, projectId);
|
||||||
|
|
||||||
|
if (!valid) {
|
||||||
|
throw new Error(`invalid_request_error: Invalid model: ${modelId}. Use /v1/models to see available models.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Optimistic Retry: If ALL accounts are rate-limited for this model, reset them to force a fresh check.
|
// Optimistic Retry: If ALL accounts are rate-limited for this model, reset them to force a fresh check.
|
||||||
// If we have some available accounts, we try them first.
|
// If we have some available accounts, we try them first.
|
||||||
if (accountManager.isAllRateLimited(modelId)) {
|
if (accountManager.isAllRateLimited(modelId)) {
|
||||||
|
|||||||
Reference in New Issue
Block a user