first commit
This commit is contained in:
70
.dockerignore
Normal file
70
.dockerignore
Normal file
@@ -0,0 +1,70 @@
|
||||
# Docker ignore file
|
||||
# Files and directories to exclude from Docker build context
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.local
|
||||
.env.development
|
||||
.env.production
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Documentation
|
||||
README.md
|
||||
DEPLOYMENT.md
|
||||
*.md
|
||||
|
||||
# Non-Docker deployment files
|
||||
nginx.conf
|
||||
itdontfitgs-api.service
|
||||
start.sh
|
||||
start_production.sh
|
||||
env.production
|
||||
env.docker
|
||||
|
||||
# Test files
|
||||
test_*.py
|
||||
*_test.py
|
||||
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
.env
|
||||
.venv
|
||||
__pycache__
|
||||
41
Dockerfile
Normal file
41
Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# Use Python 3.11 slim image
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements first for better caching
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create non-root user for security
|
||||
RUN adduser --disabled-password --gecos '' appuser && \
|
||||
chown -R appuser:appuser /app
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
|
||||
CMD python -c "import requests; requests.get('http://localhost:8000/health')" || exit 1
|
||||
|
||||
# Start the application
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||
45
auth.py
Normal file
45
auth.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import HTTPException, status, Depends, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
def verify_token(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Verify the provided token matches the configured secret token, or allow no token for plugin compatibility"""
|
||||
token = None
|
||||
|
||||
# Try Bearer token first
|
||||
if credentials:
|
||||
token = credentials.credentials
|
||||
logger.info(f"Bearer token received: {token[:10]}...")
|
||||
else:
|
||||
# Try query parameter
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
logger.info(f"Query token received: {token[:10]}...")
|
||||
else:
|
||||
# Try header without Bearer prefix
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and not auth_header.startswith("Bearer "):
|
||||
token = auth_header
|
||||
logger.info(f"Direct auth header received: {token[:10]}...")
|
||||
|
||||
# If no token provided, allow access (authentication is optional)
|
||||
if not token:
|
||||
logger.info("No token provided - allowing access")
|
||||
return None
|
||||
|
||||
# If token provided, verify it
|
||||
if token != settings.SECRET_TOKEN:
|
||||
logger.warning(f"Invalid token provided: {token[:10]}...")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
logger.info("Token verification successful")
|
||||
return token
|
||||
31
config.py
Normal file
31
config.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
class Settings:
|
||||
# Database
|
||||
DATABASE_URL: str = os.getenv("DATABASE_URL", "postgresql://username:password@localhost:5432/itdontfitgsapi")
|
||||
|
||||
# Security
|
||||
SECRET_TOKEN: str = os.getenv("SECRET_TOKEN", "your_secret_token_here")
|
||||
|
||||
# Server
|
||||
API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
|
||||
API_PORT: int = int(os.getenv("API_PORT", "8000"))
|
||||
|
||||
# Environment
|
||||
ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development")
|
||||
DEBUG: bool = os.getenv("DEBUG", "true").lower() == "true"
|
||||
|
||||
# SSL
|
||||
SSL_REDIRECT: bool = os.getenv("SSL_REDIRECT", "false").lower() == "true"
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
return self.ENVIRONMENT.lower() == "production"
|
||||
|
||||
settings = Settings()
|
||||
43
database.py
Normal file
43
database.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from datetime import datetime
|
||||
from config import settings
|
||||
|
||||
# Create database engine
|
||||
engine = create_engine(settings.DATABASE_URL)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message = Column(Text, nullable=False)
|
||||
sender = Column(String(255), nullable=False)
|
||||
item_id = Column(Integer, nullable=True)
|
||||
amount = Column(Integer, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "transactions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
item_id = Column(Integer, nullable=False)
|
||||
item = Column(String(255), nullable=False)
|
||||
user = Column(String(255), nullable=False)
|
||||
amount = Column(Integer, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# Create tables
|
||||
def create_tables():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# Dependency to get database session
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
27
docker-compose.yml
Normal file
27
docker-compose.yml
Normal file
@@ -0,0 +1,27 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# FastAPI Application
|
||||
api:
|
||||
build: .
|
||||
container_name: itdontfitgs-api
|
||||
environment:
|
||||
DATABASE_URL: ${DATABASE_URL}
|
||||
SECRET_TOKEN: ${SECRET_TOKEN}
|
||||
API_HOST: 0.0.0.0
|
||||
API_PORT: 8000
|
||||
ENVIRONMENT: production
|
||||
DEBUG: true
|
||||
SSL_REDIRECT: false # Handled by your gateway
|
||||
LOG_LEVEL: INFO
|
||||
ports:
|
||||
- "8082:8000"
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./logs:/app/logs
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import requests; requests.get('http://localhost:8000/health')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
50
init_db.py
Normal file
50
init_db.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database initialization script for ItDontFitGS API
|
||||
Run this script to create the database tables
|
||||
"""
|
||||
|
||||
from database import create_tables, engine
|
||||
from sqlalchemy import text
|
||||
import sys
|
||||
|
||||
def init_database():
|
||||
"""Initialize the database and create tables"""
|
||||
try:
|
||||
# Test database connection
|
||||
with engine.connect() as connection:
|
||||
result = connection.execute(text("SELECT 1"))
|
||||
print("✓ Database connection successful")
|
||||
|
||||
# Create tables
|
||||
create_tables()
|
||||
print("✓ Database tables created successfully")
|
||||
|
||||
# Verify tables exist
|
||||
with engine.connect() as connection:
|
||||
result = connection.execute(text("""
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name IN ('messages', 'transactions')
|
||||
ORDER BY table_name
|
||||
"""))
|
||||
tables = [row[0] for row in result]
|
||||
|
||||
if 'messages' in tables and 'transactions' in tables:
|
||||
print("✓ Tables verified: messages, transactions")
|
||||
else:
|
||||
print("✗ Table verification failed")
|
||||
return False
|
||||
|
||||
print("\n🎉 Database initialization completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Database initialization failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Initializing ItDontFitGS API database...")
|
||||
success = init_database()
|
||||
sys.exit(0 if success else 1)
|
||||
143
main.py
Normal file
143
main.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from database import create_tables
|
||||
from messages import router as messages_router
|
||||
from shared import router as shared_router
|
||||
from config import settings
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.LOG_LEVEL),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create rate limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="ItDontFitGS API",
|
||||
description="API for Old School RuneScape plugin - ItDontFitGS",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Add rate limiting
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Add CORS middleware
|
||||
if settings.is_production:
|
||||
# Production CORS - restrict to specific domains
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["https://yourdomain.com"], # Replace with your actual domain
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "DELETE"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
else:
|
||||
# Development CORS - more permissive
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add security headers middleware
|
||||
@app.middleware("http")
|
||||
async def add_security_headers(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Add HSTS header in production
|
||||
if settings.is_production:
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
return response
|
||||
|
||||
# Add request logging middleware
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
# Only log detailed info in development
|
||||
if not settings.is_production:
|
||||
logger.info(f"🔍 {request.method} {request.url}")
|
||||
logger.info(f"Headers: {dict(request.headers)}")
|
||||
|
||||
# Log body for POST requests
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
body = await request.body()
|
||||
logger.info(f"Body: {body.decode('utf-8') if body else 'Empty'}")
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# Log response in development
|
||||
if not settings.is_production:
|
||||
logger.info(f"Response: {response.status_code}")
|
||||
logger.info(f"Response headers: {dict(response.headers)}")
|
||||
|
||||
# Log response body for debugging in development
|
||||
if request.method == "GET" and request.url.path in ["/messages", "/shared"]:
|
||||
response_body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
response_body += chunk
|
||||
logger.info(f"Response body: {response_body.decode('utf-8')}")
|
||||
# Create new response with the body
|
||||
from fastapi.responses import Response
|
||||
return Response(
|
||||
content=response_body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
)
|
||||
else:
|
||||
# Production logging - just basic info
|
||||
logger.info(f"{request.method} {request.url.path} - {response.status_code}")
|
||||
|
||||
return response
|
||||
|
||||
# Include routers
|
||||
app.include_router(messages_router)
|
||||
app.include_router(shared_router)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Create database tables on startup"""
|
||||
create_tables()
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {
|
||||
"message": "ItDontFitGS API is running",
|
||||
"version": "1.0.0",
|
||||
"endpoints": {
|
||||
"messages": "/messages",
|
||||
"shared": "/shared"
|
||||
}
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=settings.API_HOST,
|
||||
port=settings.API_PORT,
|
||||
reload=True
|
||||
)
|
||||
65
messages.py
Normal file
65
messages.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from database import get_db, Message
|
||||
from schemas import MessageCreate, Message as MessageSchema
|
||||
from auth import verify_token
|
||||
|
||||
# Create rate limiter instance
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
|
||||
@router.get("", response_model=List[MessageSchema])
|
||||
@router.get("/", response_model=List[MessageSchema])
|
||||
@limiter.limit("30/minute")
|
||||
def get_messages(
|
||||
request: Request,
|
||||
token: str = Depends(verify_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all messages"""
|
||||
messages = db.query(Message).order_by(Message.created_at.asc()).all()
|
||||
return messages
|
||||
|
||||
@router.post("", response_model=MessageSchema)
|
||||
@router.post("/", response_model=MessageSchema)
|
||||
@limiter.limit("10/minute")
|
||||
def create_message(
|
||||
request: Request,
|
||||
message: MessageCreate,
|
||||
token: str = Depends(verify_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a new message"""
|
||||
db_message = Message(
|
||||
message=message.message,
|
||||
sender=message.sender,
|
||||
item_id=message.item_id,
|
||||
amount=message.amount
|
||||
)
|
||||
db.add(db_message)
|
||||
db.commit()
|
||||
db.refresh(db_message)
|
||||
return db_message
|
||||
|
||||
@router.delete("/{message_id}")
|
||||
@limiter.limit("20/minute")
|
||||
def delete_message(
|
||||
request: Request,
|
||||
message_id: int,
|
||||
token: str = Depends(verify_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Delete a message by ID"""
|
||||
message = db.query(Message).filter(Message.id == message_id).first()
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Message not found"
|
||||
)
|
||||
db.delete(message)
|
||||
db.commit()
|
||||
return {"message": "Message deleted successfully"}
|
||||
40
recreate_tables.py
Normal file
40
recreate_tables.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to recreate database tables with updated schema
|
||||
"""
|
||||
|
||||
from database import engine, Base
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def recreate_tables():
|
||||
"""Drop and recreate all tables"""
|
||||
try:
|
||||
with engine.connect() as connection:
|
||||
# Drop existing tables
|
||||
logger.info("Dropping existing tables...")
|
||||
connection.execute(text("DROP TABLE IF EXISTS messages CASCADE"))
|
||||
connection.execute(text("DROP TABLE IF EXISTS transactions CASCADE"))
|
||||
connection.commit()
|
||||
logger.info("✓ Tables dropped")
|
||||
|
||||
# Create new tables with updated schema
|
||||
logger.info("Creating new tables...")
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("✓ Tables created with updated schema")
|
||||
|
||||
print("\n🎉 Database tables recreated successfully!")
|
||||
print("The plugin should now work with optional item_id and amount fields.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error recreating tables: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Recreating database tables with updated schema...")
|
||||
recreate_tables()
|
||||
26
requirements.txt
Normal file
26
requirements.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
alembic==1.16.5
|
||||
annotated-types==0.7.0
|
||||
anyio==4.11.0
|
||||
click==8.3.0
|
||||
Deprecated==1.2.18
|
||||
fastapi==0.117.1
|
||||
greenlet==3.2.4
|
||||
h11==0.16.0
|
||||
idna==3.10
|
||||
limits==5.5.0
|
||||
Mako==1.3.10
|
||||
MarkupSafe==3.0.2
|
||||
packaging==25.0
|
||||
psycopg2-binary==2.9.10
|
||||
pydantic==2.11.9
|
||||
pydantic_core==2.33.2
|
||||
python-dotenv==1.1.1
|
||||
python-multipart==0.0.20
|
||||
slowapi==0.1.9
|
||||
sniffio==1.3.1
|
||||
SQLAlchemy==2.0.43
|
||||
starlette==0.48.0
|
||||
typing-inspection==0.4.1
|
||||
typing_extensions==4.15.0
|
||||
uvicorn==0.37.0
|
||||
wrapt==1.17.3
|
||||
97
schemas.py
Normal file
97
schemas.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from pydantic import BaseModel, field_serializer, field_validator, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import re
|
||||
import html
|
||||
|
||||
# Message schemas
|
||||
class MessageBase(BaseModel):
|
||||
message: str = Field(..., min_length=1, max_length=1000, description="Message content")
|
||||
sender: str = Field(..., min_length=1, max_length=255, description="Sender name")
|
||||
item_id: Optional[int] = Field(None, ge=0, description="Item ID")
|
||||
amount: Optional[int] = Field(None, ge=0, description="Item amount")
|
||||
|
||||
# No validators on base class - they will be added to Create classes only
|
||||
|
||||
class MessageCreate(MessageBase):
|
||||
@field_validator('message', 'sender')
|
||||
@classmethod
|
||||
def sanitize_text(cls, v):
|
||||
"""Sanitize text input to prevent XSS and other attacks"""
|
||||
if not v:
|
||||
return v
|
||||
# HTML escape to prevent XSS
|
||||
v = html.escape(v.strip())
|
||||
# Remove any remaining script tags or dangerous patterns
|
||||
v = re.sub(r'<script.*?</script>', '', v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r'javascript:', '', v, flags=re.IGNORECASE)
|
||||
return v
|
||||
|
||||
@field_validator('sender')
|
||||
@classmethod
|
||||
def validate_sender(cls, v):
|
||||
"""Validate sender name format"""
|
||||
if not v:
|
||||
return v
|
||||
# Allow only alphanumeric, spaces, and common characters
|
||||
if not re.match(r'^[a-zA-Z0-9\s\-_\.]+$', v):
|
||||
raise ValueError('Sender name contains invalid characters')
|
||||
return v
|
||||
|
||||
class Message(MessageBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
@field_serializer('created_at')
|
||||
def serialize_created_at(self, value: datetime) -> str:
|
||||
"""Format datetime to match plugin expectations: 'u-M-d H:m:s' (no leading zeros)"""
|
||||
return value.strftime("%Y-%-m-%-d %-H:%M:%S")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
# Transaction schemas
|
||||
class TransactionBase(BaseModel):
|
||||
item_id: int = Field(..., ge=0, description="Item ID")
|
||||
item: str = Field(..., min_length=1, max_length=255, description="Item name")
|
||||
user: str = Field(..., min_length=1, max_length=255, description="User name")
|
||||
amount: int = Field(..., description="Transaction amount")
|
||||
|
||||
# No validators on base class - they will be added to Create classes only
|
||||
|
||||
class TransactionCreate(TransactionBase):
|
||||
@field_validator('item', 'user')
|
||||
@classmethod
|
||||
def sanitize_text(cls, v):
|
||||
"""Sanitize text input to prevent XSS and other attacks"""
|
||||
if not v:
|
||||
return v
|
||||
# HTML escape to prevent XSS
|
||||
v = html.escape(v.strip())
|
||||
# Remove any remaining script tags or dangerous patterns
|
||||
v = re.sub(r'<script.*?</script>', '', v, flags=re.IGNORECASE | re.DOTALL)
|
||||
v = re.sub(r'javascript:', '', v, flags=re.IGNORECASE)
|
||||
return v
|
||||
|
||||
@field_validator('user')
|
||||
@classmethod
|
||||
def validate_user(cls, v):
|
||||
"""Validate user name format"""
|
||||
if not v:
|
||||
return v
|
||||
# Allow only alphanumeric, spaces, and common characters
|
||||
if not re.match(r'^[a-zA-Z0-9\s\-_\.]+$', v):
|
||||
raise ValueError('User name contains invalid characters')
|
||||
return v
|
||||
|
||||
class Transaction(TransactionBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
@field_serializer('created_at')
|
||||
def serialize_created_at(self, value: datetime) -> str:
|
||||
"""Format datetime to match plugin expectations: 'u-M-d H:m:s' (no leading zeros)"""
|
||||
return value.strftime("%Y-%-m-%-d %-H:%M:%S")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
74
shared.py
Normal file
74
shared.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from typing import List, Union, Optional
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from database import get_db, Transaction
|
||||
from schemas import TransactionCreate, Transaction as TransactionSchema
|
||||
from auth import verify_token
|
||||
import json
|
||||
|
||||
# Create rate limiter instance
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
router = APIRouter(prefix="/shared", tags=["shared"])
|
||||
|
||||
@router.get("", response_model=List[TransactionSchema])
|
||||
@router.get("/", response_model=List[TransactionSchema])
|
||||
@limiter.limit("30/minute")
|
||||
def get_transactions(
|
||||
request: Request,
|
||||
search: Optional[str] = Query(None, description="Search term for item names"),
|
||||
token: str = Depends(verify_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get all bank transactions, optionally filtered by search term"""
|
||||
query = db.query(Transaction)
|
||||
|
||||
# Apply search filter if provided
|
||||
if search and search.strip():
|
||||
search_term = f"%{search.strip().lower()}%"
|
||||
query = query.filter(func.lower(Transaction.item).like(search_term))
|
||||
|
||||
transactions = query.order_by(Transaction.created_at.asc()).all()
|
||||
return transactions
|
||||
|
||||
@router.post("", response_model=List[TransactionSchema])
|
||||
@router.post("/", response_model=List[TransactionSchema])
|
||||
@limiter.limit("20/minute")
|
||||
async def create_transaction(
|
||||
request: Request,
|
||||
token: str = Depends(verify_token),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create one or more bank transactions"""
|
||||
# Get raw JSON data
|
||||
body = await request.body()
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
|
||||
# Handle both single transaction and array of transactions
|
||||
if isinstance(data, dict):
|
||||
transactions_data = [data]
|
||||
else:
|
||||
transactions_data = data
|
||||
|
||||
created_transactions = []
|
||||
|
||||
for transaction_data in transactions_data:
|
||||
db_transaction = Transaction(
|
||||
item_id=transaction_data.get('item_id'),
|
||||
item=transaction_data.get('item'),
|
||||
user=transaction_data.get('user'),
|
||||
amount=transaction_data.get('amount')
|
||||
)
|
||||
db.add(db_transaction)
|
||||
created_transactions.append(db_transaction)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Refresh all transactions
|
||||
for transaction in created_transactions:
|
||||
db.refresh(transaction)
|
||||
|
||||
return created_transactions
|
||||
Reference in New Issue
Block a user