Skip to content

Commit

Permalink
Merge pull request #36 from salesforce/mohith/huggiface-connector
Browse files Browse the repository at this point in the history
feat: Hugging Face recipe for the Open LLM Connector
  • Loading branch information
tatedorman authored Nov 11, 2024
2 parents e8e5559 + f1a0743 commit 7854dad
Show file tree
Hide file tree
Showing 14 changed files with 668 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Dependencies
node_modules/

# Environment variables
.env

# Logs
*.log

# Optional: OS generated files
.DS_Store
Thumbs.db

# Optional: Editor directories and files
.idea/
.vscode/
*.swp
*.swo
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"semi": true,
"trailingComma": "es5",
"singleQuote": true,
"printWidth": 100,
"tabWidth": 4,
"useTabs": false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
web: node index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import dotenv from 'dotenv';

dotenv.config();

const validateEnvironment = () => {
const requiredEnvVars = ['HUGGING_FACE_API_KEY'];

const missingVars = requiredEnvVars.filter((varName) => !process.env[varName]);

if (missingVars.length > 0) {
throw new Error(
`Missing required environment variables: ${missingVars.join(', ')}\n` +
'Please check your .env file or environment configuration.'
);
}
};

// Validate environment variables immediately
validateEnvironment();

export default {
port: process.env.PORT || 3000,
huggingFaceApiKey: process.env.HUGGING_FACE_API_KEY,
huggingFaceApiUrl: 'https://api-inference.huggingface.co/models/',
corsOptions: {
origin: process.env.ALLOWED_ORIGINS ? process.env.ALLOWED_ORIGINS.split(',') : [],
methods: ['POST'],
allowedHeaders: ['Content-Type', 'Authorization'],
},
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import axios from 'axios';
import { v4 as uuidv4 } from 'uuid';
import Joi from 'joi';
import config from '../config/index.js';
import winston from 'winston';

const logger = winston.createLogger({
level: 'info',
format: winston.format.json(),
transports: [
new winston.transports.Console({
format: winston.format.simple(),
}),
],
});

const chatCompletionSchema = Joi.object({
messages: Joi.array()
.items(
Joi.object({
role: Joi.string().valid('system', 'user', 'assistant').required(),
content: Joi.string().required(),
})
)
.min(1)
.required(),
model: Joi.string().required(),
max_tokens: Joi.number().integer().min(1).default(500),
temperature: Joi.number().min(0).max(1),
n: Joi.number().integer().min(1).default(1),
parameters: Joi.object({
top_p: Joi.number().min(0).max(1),
}),
});

export const chatCompletion = async (req, res, next) => {
try {
const { error, value } = chatCompletionSchema.validate(req.body);
if (error) {
return res.status(400).json({ error: error.details[0].message });
}

// Optimize message processing
const systemMessages = [];
const otherMessages = [];
for (const message of value.messages) {
if (message.role === 'system') {
systemMessages.push(message.content);
} else {
otherMessages.push(message);
}
}

const processedMessages = systemMessages.length > 0
? [{ role: 'system', content: systemMessages.join('\n') }, ...otherMessages]
: otherMessages;

const huggingFaceRequestBody = {
model: value.model,
messages: processedMessages,
max_tokens: value.max_tokens,
stream: false,
};

if (value.temperature !== undefined) {
huggingFaceRequestBody.temperature = value.temperature;
}
if (value.parameters && value.parameters.top_p !== undefined) {
huggingFaceRequestBody.top_p = value.parameters.top_p;
}

const response = await axios.post(
`${config.huggingFaceApiUrl}${value.model}/v1/chat/completions`,
huggingFaceRequestBody,
{
headers: {
Authorization: `Bearer ${config.huggingFaceApiKey}`,
'Content-Type': 'application/json',
},
}
);

const reshapedResponse = {
id: uuidv4(),
object: response.data.object,
created: response.data.created,
model: response.data.model,
choices: response.data.choices,
usage: response.data.usage,
};

res.status(200).json(reshapedResponse);
} catch (error) {
logger.error('Error in chat completion:', {
status: error.response?.status,
statusText: error.response?.statusText,
response: error.response?.data,
});
next(error);
}
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import 'dotenv/config';
import express from 'express';
import cors from 'cors';
import helmet from 'helmet';
import rateLimit from 'express-rate-limit';
import { validateApiKey, errorHandler } from './middleware/index.js';
import config from './config/index.js';
import createSanitizedLogger from './utils/logger.js';
import chatRoutes from './routes/chat.js';

// Create logger with sensitive data filtering
const logger = createSanitizedLogger();

const app = express();

// Middleware
app.use(express.json());

// Helmet configuration with strict security settings
app.use(
helmet({
contentSecurityPolicy: {
directives: {
defaultSrc: ["'self'"],
scriptSrc: ["'self'", "'unsafe-inline'"],
styleSrc: ["'self'", "'unsafe-inline'"],
imgSrc: ["'self'", 'data:', 'https:'],
connectSrc: ["'self'"],
fontSrc: ["'self'"],
objectSrc: ["'none'"],
mediaSrc: ["'self'"],
frameSrc: ["'none'"],
},
},
crossOriginEmbedderPolicy: true,
crossOriginOpenerPolicy: { policy: 'same-origin' },
crossOriginResourcePolicy: { policy: 'same-origin' },
dnsPrefetchControl: { allow: false },
expectCt: { maxAge: 86400, enforce: true },
frameguard: { action: 'deny' },
hsts: { maxAge: 31536000, includeSubDomains: true, preload: true },
ieNoOpen: true,
noSniff: true,
originAgentCluster: true,
permittedCrossDomainPolicies: { permittedPolicies: 'none' },
referrerPolicy: { policy: 'strict-origin-when-cross-origin' },
xssFilter: true,
})
);

app.use(cors(config.corsOptions));

// Rate limiting
const limiter = rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 100, // limit each IP to 100 requests per windowMs
});
app.use(limiter);

// Routes
app.use('/chat', validateApiKey, chatRoutes);

// Error handling middleware
app.use(errorHandler);

const PORT = config.port;
app.listen(PORT, () => {
logger.info(`Server is running on port ${PORT}`);
});

export default app;
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import config from '../config/index.js';
import createSanitizedLogger from '../utils/logger.js';

const logger = createSanitizedLogger();

export const validateApiKey = (req, res, next) => {
const apiKey = req.headers['api-key'];
if (!apiKey || apiKey !== config.huggingFaceApiKey) {
return res.status(401).json({ error: 'Unauthorized: Invalid API key' });
}
next();
};

export const errorHandler = (err, req, res, next) => {
logger.error(`Error name: ${err.name}`);
logger.error(`Error message: ${err.message}`);
logger.error(`Error code: ${err.code}`);
logger.error(`Error stack: ${err.stack}`);
logger.error(`Request: ${req.method} ${req.originalUrl}`);
logger.error(`Request body: ${JSON.stringify(req.body)}`);

const statusCode = err.statusCode || 500;
const message = err.message || 'Internal Server Error';
res.status(statusCode).json({
error: {
status: statusCode,
message: message,
}
});
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"name": "heroku-node-app",
"version": "1.0.0",
"description": "Heroku Node.js app to proxy chat completions",
"main": "index.js",
"scripts": {
"start": "node index.js",
"format:write": "prettier --write ."
},
"type": "module",
"engines": {
"node": "20.x"
},
"dependencies": {
"axios": "^1.2.0",
"cors": "^2.8.5",
"dotenv": "^16.4.5",
"express": "^4.18.2",
"express-rate-limit": "^5.2.6",
"helmet": "^4.6.0",
"joi": "^17.4.0",
"uuid": "^8.3.2",
"winston": "^3.3.3"
},
"devDependencies": {
"@types/express-rate-limit": "^6.0.0",
"prettier": "^3.3.3"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import express from 'express';
import { chatCompletion } from '../controllers/chatController.js';

const router = express.Router();

router.post('/completions', chatCompletion);

export default router;
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import winston from 'winston';

const createSanitizedLogger = (options = {}) => {
const defaultOptions = {
level: 'info',
format: winston.format.combine(
winston.format.simple(),
winston.format.printf(({ level, message }) => {
const sanitizedMessage = message
.replace(/Authorization:.*?(?=\s|$)/gi, 'Authorization: [REDACTED]')
.replace(/api[_-]?key:.*?(?=\s|$)/gi, 'api_key: [REDACTED]')
.replace(/Bearer\s+[A-Za-z0-9-._~+/]+=*/g, 'Bearer [REDACTED]')
.replace(/"token":\s*"[^"]*"/g, '"token": "[REDACTED]"')
.replace(/[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}/g, '[EMAIL REDACTED]');
return `${level}: ${sanitizedMessage}`;
})
),
transports: [new winston.transports.Console()],
};

const mergedOptions = { ...defaultOptions, ...options };

return winston.createLogger(mergedOptions);
};

export default createSanitizedLogger;
9 changes: 9 additions & 0 deletions documentation/cookbook/authors.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ tatedorman:
socials:
github: tatedorman

msrivastav13:
name: Mohith
title: Developer Advocate @ Salesforce
url: https://github.com/msrivastav13
image_url: https://github.com/msrivastav13.png
page: true
socials:
github: msrivastav13

# sfyz:
# name: Yuvi
# title: Technical Writer @ Salesforce
Expand Down
Loading

0 comments on commit 7854dad

Please sign in to comment.