144 lines
4.4 KiB
Python
144 lines
4.4 KiB
Python
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
from slowapi.util import get_remote_address
|
|
from slowapi.errors import RateLimitExceeded
|
|
from database import create_tables
|
|
from messages import router as messages_router
|
|
from shared import router as shared_router
|
|
from config import settings
|
|
import logging
|
|
|
|
def _resolve_log_level(name):
    """Return the numeric ``logging`` level for *name*, or ``None`` if it is not one.

    Integers are returned unchanged. Strings are matched case-insensitively and
    with surrounding whitespace ignored, so ``"info"``, ``" INFO "`` and
    ``"Info"`` all resolve to ``logging.INFO``. Any other value is converted with
    ``str()`` before the lookup.
    """
    if isinstance(name, int) and not isinstance(name, bool):
        return name
    level = getattr(logging, str(name).strip().upper(), None)
    # Only accept real level constants. This rejects other module attributes
    # such as BASIC_FORMAT (a string) or a missing name (None).
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    return None


# Configure logging. An unrecognised LOG_LEVEL no longer aborts startup; it
# falls back to INFO and is reported once the logger exists.
_log_level = _resolve_log_level(settings.LOG_LEVEL)

logging.basicConfig(
    level=logging.INFO if _log_level is None else _log_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

logger = logging.getLogger(__name__)

if _log_level is None:
    logger.warning("Unknown LOG_LEVEL %r; falling back to INFO", settings.LOG_LEVEL)
|
|
|
|
# The FastAPI application object that the routers and middleware below attach to.
app = FastAPI(
    title="ItDontFitGS API",
    description="API for Old School RuneScape plugin - ItDontFitGS",
    version="1.0.0",
)

# Rate limiting (slowapi): requests are bucketed by the key returned from
# get_remote_address. The limiter is stored on app.state and RateLimitExceeded
# is routed to slowapi's stock handler, following slowapi's documented setup.
# NOTE(review): behind a reverse proxy the remote address may be the proxy's,
# so all clients could share one bucket — confirm the deployment setup.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
|
|
|
# Placeholder origin from the original configuration. It is used in production
# only when the settings object does not provide CORS_ORIGINS.
_DEFAULT_PRODUCTION_ORIGINS = ("https://yourdomain.com",)  # Replace with your actual domain


def _production_cors_origins():
    """Return the list of browser origins allowed to call the API in production.

    Uses ``settings.CORS_ORIGINS`` when the settings object defines a non-empty
    value. That value may be an iterable of origin strings or a single
    comma-separated string. Blank entries are dropped. Otherwise the function
    falls back to ``_DEFAULT_PRODUCTION_ORIGINS``.
    """
    configured = getattr(settings, "CORS_ORIGINS", None)
    if isinstance(configured, str):
        configured = configured.split(",")
    origins = [str(origin).strip() for origin in configured or ()]
    origins = [origin for origin in origins if origin]
    return origins or list(_DEFAULT_PRODUCTION_ORIGINS)


# Add CORS middleware
if settings.is_production:
    # Production CORS - restrict to the configured domains
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_production_cors_origins(),
        allow_credentials=True,
        allow_methods=["GET", "POST", "DELETE"],
        allow_headers=["*"],
    )
else:
    # Development CORS - any origin, method and header.
    # NOTE(review): the CORS spec does not allow a wildcard origin together with
    # credentials, and Starlette's docs say the wildcards must be replaced by
    # explicit values for credentials to be allowed. Keep this for local
    # development only.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
|
|
|
|
# Add security headers middleware
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Attach defensive HTTP security headers to every response.

    Always sets ``X-Content-Type-Options``, ``X-Frame-Options``,
    ``X-XSS-Protection`` and ``Referrer-Policy``. When ``settings.is_production``
    is true it also sets ``Strict-Transport-Security``. Existing values for
    these headers are overwritten.
    """
    response = await call_next(request)

    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    # The legacy browser XSS auditor is deprecated, and "1; mode=block" could
    # itself be abused in older browsers. OWASP recommends disabling it
    # explicitly with "0".
    response.headers["X-XSS-Protection"] = "0"
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

    if settings.is_production:
        # HSTS for one year (31536000 seconds), covering subdomains as well.
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

    return response
|
|
|
|
# Add request logging middleware

# Header names whose values are masked before headers are written to the
# development log, so credentials and session cookies do not end up in log files.
_SENSITIVE_LOG_HEADERS = frozenset(
    {"authorization", "proxy-authorization", "cookie", "set-cookie", "x-api-key"}
)

# GET paths whose response bodies are echoed to the development log.
_BODY_LOGGED_GET_PATHS = frozenset({"/messages", "/shared"})


def _headers_for_log(headers):
    """Return *headers* as a plain dict with sensitive values replaced by ``"<redacted>"``."""
    return {
        name: "<redacted>" if name.lower() in _SENSITIVE_LOG_HEADERS else value
        for name, value in headers.items()
    }


def _decode_for_log(payload):
    """Decode *payload* bytes as UTF-8 for logging; invalid bytes become U+FFFD instead of raising."""
    return payload.decode("utf-8", errors="replace")


async def _replay_body(payload):
    """Yield an already-consumed response body once, so it can be streamed to the client again."""
    yield payload


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request and the response produced for it.

    In development (``settings.is_production`` is false) this logs:

    - the method and full URL;
    - the request headers, with sensitive values redacted;
    - the body of POST/PUT/PATCH requests;
    - the response status and headers, redacted the same way;
    - for GET requests to ``/messages`` and ``/shared``, the response body.

    In production it logs a single ``METHOD path - status`` line. Logging never
    fails the request: undecodable bytes are replaced rather than raising.
    """
    if not settings.is_production:
        logger.info(f"🔍 {request.method} {request.url}")
        logger.info(f"Headers: {_headers_for_log(request.headers)}")

        # Log body for POST/PUT/PATCH requests.
        # NOTE(review): this relies on the installed Starlette letting the
        # endpoint re-read a body already read in an @app.middleware("http")
        # function. Older versions did not — confirm against the pinned version.
        if request.method in ("POST", "PUT", "PATCH"):
            body = await request.body()
            logger.info(f"Body: {_decode_for_log(body) if body else 'Empty'}")

    response = await call_next(request)

    if settings.is_production:
        # Production logging - just basic info
        logger.info(f"{request.method} {request.url.path} - {response.status_code}")
        return response

    logger.info(f"Response: {response.status_code}")
    logger.info(f"Response headers: {_headers_for_log(response.headers)}")

    if request.method == "GET" and request.url.path in _BODY_LOGGED_GET_PATHS:
        # Draining body_iterator consumes the body. Put the captured bytes back on
        # the same response object so status and headers (including repeated ones
        # such as Set-Cookie) reach the client unchanged.
        chunks = [chunk async for chunk in response.body_iterator]
        response_body = b"".join(chunks)
        logger.info(f"Response body: {_decode_for_log(response_body)}")
        response.body_iterator = _replay_body(response_body)

    return response
|
|
|
|
# Mount the feature routers on the application, messages first and then shared.
for _feature_router in (messages_router, shared_router):
    app.include_router(_feature_router)
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook that asks the database layer to create its tables.

    Delegates entirely to ``database.create_tables``.
    """
    # create_tables() is invoked without await, so it runs inline on the event
    # loop while the application starts.
    # NOTE(review): on_event is deprecated in recent FastAPI in favour of
    # lifespan handlers — consider migrating together with the app constructor.
    create_tables()
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: a running message, the API version and the main endpoint paths."""
    endpoint_paths = {"messages": "/messages", "shared": "/shared"}
    return {
        "message": "ItDontFitGS API is running",
        "version": "1.0.0",
        "endpoints": endpoint_paths,
    }
|
|
|
|
@app.get("/health")
async def health_check():
    """Report that the process is answering HTTP requests.

    Always returns the same payload. It does not check the database or any other
    dependency.
    """
    status_payload = {"status": "healthy"}
    return status_payload
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Run the app with uvicorn when this file is executed directly.
    # NOTE(review): "main:app" assumes this module is importable as `main` — confirm the filename.
    # Auto-reload watches source files and restarts the server. It is a
    # development convenience, so it is disabled when settings report production.
    uvicorn.run(
        "main:app",
        host=settings.API_HOST,
        port=settings.API_PORT,
        reload=not settings.is_production,
    )
|