import csv
import logging
import os
from io import StringIO
from typing import Any, Dict, List, Optional
from uuid import uuid4

from sqlalchemy.orm import Session

from db.models import File, StagedFileProduct
from db.utils import db_transaction
from schemas.file import CreateFileRequest

logger = logging.getLogger(__name__)


class FileConfig:
    """Configuration constants for file processing"""

    # Directory uploaded files are written to. Resolved against the process
    # working directory once, when this module is imported.
    TEMP_DIR = os.path.join(os.getcwd(), 'app/temp')

    # Column names that must all be present in the first row of a ManaBox CSV.
    MANABOX_HEADERS = [
        'Name',
        'Set code',
        'Set name',
        'Collector number',
        'Foil',
        'Rarity',
        'Quantity',
        'ManaBox ID',
        'Scryfall ID',
        'Purchase price',
        'Misprint',
        'Altered',
        'Condition',
        'Language',
        'Purchase price currency',
    ]

    # Per-source validation rules, keyed by the ``source`` field of the
    # upload metadata.
    SOURCES = {
        "manabox": {
            "required_headers": MANABOX_HEADERS,
            "allowed_extensions": ['.csv'],
            "allowed_types": ['scan_export_common', 'scan_export_rare'],
        }
    }


class FileValidationError(Exception):
    """Custom exception for file validation errors"""


class FileService:
    """Validate uploaded files, persist them, and manage their DB records."""

    def __init__(self, db: Session):
        self.db = db

    def get_config(self, source: str) -> Dict[str, Any]:
        """Get configuration for a specific source.

        Raises:
            FileValidationError: if ``source`` has no entry in
                ``FileConfig.SOURCES``.
        """
        config = FileConfig.SOURCES.get(source)
        if not config:
            raise FileValidationError(f"Unsupported source: {source}")
        return config

    def validate_file_extension(self, filename: str, config: Dict[str, Any]) -> bool:
        """Validate file extension against allowed extensions.

        The comparison is case-insensitive so that it agrees with
        ``validate_file_content``, which lowercases the extension before
        dispatching on it (``EXPORT.CSV`` is as valid as ``export.csv``).
        """
        allowed = tuple(ext.lower() for ext in config["allowed_extensions"])
        # str.endswith accepts a tuple; an empty tuple yields False.
        return filename.lower().endswith(allowed)

    def validate_file_type(self, metadata: CreateFileRequest, config: Dict[str, Any]) -> bool:
        """Validate file type against allowed types"""
        return metadata.type in config["allowed_types"]

    def validate_csv(self, content: bytes, required_headers: Optional[List[str]] = None) -> bool:
        """Validate CSV content and headers.

        Args:
            content: Raw file bytes; must decode as UTF-8 (a leading BOM is
                tolerated).
            required_headers: Column names that must all appear in the first
                row. When empty or ``None`` only the decoding is checked.

        Returns:
            ``True`` if the content decodes and carries every required
            header, ``False`` otherwise. Decode and CSV parsing errors are
            logged and reported as ``False`` rather than raised.
        """
        try:
            # 'utf-8-sig' strips a leading byte-order mark if present;
            # with plain 'utf-8' the BOM would stay glued to the first
            # header cell and make the required-header check fail.
            csv_text = content.decode('utf-8-sig')
            csv_reader = csv.reader(StringIO(csv_text))

            if required_headers:
                headers = next(csv_reader, None)
                if not headers or not set(required_headers).issubset(headers):
                    return False
            return True
        except (UnicodeDecodeError, csv.Error) as e:
            logger.error("CSV validation error: %s", e)
            return False

    def validate_file_content(self, content: bytes, metadata: CreateFileRequest, config: Dict[str, Any]) -> bool:
        """Validate file content based on file type.

        Only ``.csv`` files have a content validator; any other extension is
        reported as invalid.
        """
        extension = os.path.splitext(metadata.filename)[1].lower()
        if extension == '.csv':
            return self.validate_csv(content, config.get("required_headers"))
        return False

    def validate_file(self, content: bytes, metadata: CreateFileRequest) -> bool:
        """Validate file against all criteria.

        Checks, in order: known source, allowed extension, allowed type, and
        content/headers.

        Returns:
            ``True`` when every check passes.

        Raises:
            FileValidationError: naming the first check that failed.
        """
        config = self.get_config(metadata.source)

        if not self.validate_file_extension(metadata.filename, config):
            raise FileValidationError("Invalid file extension")

        if not self.validate_file_type(metadata, config):
            raise FileValidationError("Invalid file type")

        if not self.validate_file_content(content, metadata, config):
            raise FileValidationError("Invalid file content or headers")

        return True

    def create_file(self, content: bytes, metadata: CreateFileRequest) -> File:
        """Create a new file record and save the file.

        The file is written under ``FileConfig.TEMP_DIR`` using only the
        final path component of ``metadata.filename``, so a caller-supplied
        name cannot direct the write outside that directory.

        Raises:
            FileValidationError: if the filename has no usable final
                component (empty, ``.`` or ``..``).
        """
        # Treat both separator styles as separators regardless of platform,
        # then keep only the last component.
        safe_name = os.path.basename(metadata.filename.replace('\\', '/'))
        if safe_name in ('', '.', '..'):
            raise FileValidationError("Invalid filename")

        # NOTE(review): two uploads with the same filename share one path and
        # the later write overwrites the earlier file — confirm whether that
        # is intended.
        # NOTE(review): db_transaction presumably commits on success and rolls
        # back on error — confirm in db.utils. A file already written is not
        # removed if the commit fails.
        with db_transaction(self.db):
            file = File(
                id=str(uuid4()),
                filename=metadata.filename,
                filepath=os.path.join(FileConfig.TEMP_DIR, safe_name),
                type=metadata.type,
                source=metadata.source,
                filesize_kb=round(len(content) / 1024, 2),
                status='pending',
                service=metadata.service,
            )
            self.db.add(file)

            os.makedirs(FileConfig.TEMP_DIR, exist_ok=True)
            with open(file.filepath, 'wb') as f:
                f.write(content)

        return file

    def get_file(self, file_id: str) -> File:
        """Get a file by ID.

        Raises:
            FileValidationError: if no record has this id.
        """
        file = self.db.query(File).filter(File.id == file_id).first()
        if not file:
            raise FileValidationError(f"File with id {file_id} not found")
        return file

    def get_files(self, status: Optional[str] = None) -> List[File]:
        """Get all files, optionally filtered by status"""
        query = self.db.query(File)
        if status:
            query = query.filter(File.status == status)
        return query.all()

    def get_staged_products(self, file_id: str) -> List[StagedFileProduct]:
        """Get staged products for a file"""
        return self.db.query(StagedFileProduct).filter(
            StagedFileProduct.file_id == file_id
        ).all()

    def delete_file(self, file_id: str) -> File:
        """Mark a file as deleted and remove associated staged products.

        The record is kept with ``status='deleted'``; the file on disk is not
        touched by this method.

        Raises:
            FileValidationError: if no record has this id.
        """
        file = self.get_file(file_id)
        staged_products = self.get_staged_products(file_id)

        with db_transaction(self.db):
            file.status = 'deleted'
            for staged_product in staged_products:
                self.db.delete(staged_product)

        return file

    def get_file_content(self, file_id: str) -> bytes:
        """Get the content of a file.

        Raises:
            FileValidationError: if the record does not exist or the file at
                its stored path cannot be read.
        """
        file = self.get_file(file_id)
        try:
            with open(file.filepath, 'rb') as f:
                return f.read()
        except IOError as e:
            logger.error("Error reading file %s: %s", file_id, e)
            # Chain the original error so the underlying I/O cause stays
            # visible in tracebacks.
            raise FileValidationError(f"Could not read file content for {file_id}") from e