import csv
import logging
import os
from io import StringIO
from typing import Optional, List
from uuid import uuid4 as uuid

from sqlalchemy.orm import Session

from db.utils import db_transaction
from db.models import File, StagedFileProduct
from schemas.file import CreateFileRequest

logger = logging.getLogger(__name__)

# Header row of a ManaBox export, in order:
# Name,Set code,Set name,Collector number,Foil,Rarity,Quantity,ManaBox ID,Scryfall ID,Purchase price,Misprint,Altered,Condition,Language,Purchase price currency
MANABOX_REQUIRED_FILE_HEADERS = [
    'Name',
    'Set code',
    'Set name',
    'Collector number',
    'Foil',
    'Rarity',
    'Quantity',
    'ManaBox ID',
    'Scryfall ID',
    'Purchase price',
    'Misprint',
    'Altered',
    'Condition',
    'Language',
    'Purchase price currency',
]
MANABOX_ALLOWED_FILE_EXTENSIONS = ['.csv']
MANABOX_ALLOWED_FILE_TYPES = ['scan_export']
MANABOX_CONFIG = {
    "required_headers": MANABOX_REQUIRED_FILE_HEADERS,
    "allowed_extensions": MANABOX_ALLOWED_FILE_EXTENSIONS,
    "allowed_types": MANABOX_ALLOWED_FILE_TYPES
}

# Per-source validation rules, keyed by the `source` field of the upload metadata.
SOURCES = {
    "manabox": MANABOX_CONFIG
}

# Directory uploaded files are written to, relative to the process working directory.
TEMP_DIR = os.getcwd() + '/temp/'


class FileService:
    """Validate uploaded files, store them on disk and track them as `File` rows."""

    def __init__(self, db: Session):
        """Keep the SQLAlchemy session used for every query in this service."""
        self.db = db

    # CONFIG
    def get_config(self, source: str) -> Optional[dict]:
        """Return the validation config registered for `source`, or None if unknown."""
        return SOURCES.get(source)

    # VALIDATION
    def validate_file_extension(self, filename: str, config: dict) -> bool:
        """Return True if `filename` ends with one of the config's allowed extensions."""
        return filename.endswith(tuple(config.get("allowed_extensions")))

    def validate_file_type(self, metadata: CreateFileRequest, config: dict) -> bool:
        """Return True if `metadata.type` is one of the config's allowed types."""
        return metadata.type in config.get("allowed_types")

    def validate_csv(self, content: bytes, required_headers: Optional[List[str]] = None) -> bool:
        """Check that `content` is UTF-8 CSV containing every required header.

        Args:
            content: Raw bytes of the uploaded file.
            required_headers: Column names that must all appear in the first
                row. When empty or None, only decoding/parsing is checked.

        Returns:
            True if the content decodes, its first row parses, and no required
            header is missing; False otherwise.
        """
        try:
            # Try to decode and parse as CSV
            csv_text = content.decode('utf-8')
            csv_reader = csv.reader(StringIO(csv_text))

            # Check headers if specified
            headers = next(csv_reader, None)
        except (UnicodeDecodeError, csv.Error):
            return False

        if not required_headers:
            return True
        if headers is None:
            # Empty file: there is no header row, so required headers cannot match.
            return False

        present = set(headers)
        return all(header in present for header in required_headers)

    def validate_file_content(self, content: bytes, metadata: CreateFileRequest, config: dict) -> bool:
        """Validate `content` according to the extension of `metadata.filename`.

        Only 'csv' is supported; any other extension returns False.
        """
        extension = metadata.filename.split('.')[-1]
        if extension == 'csv':
            return self.validate_csv(content, config.get("required_headers"))
        return False

    def validate_file(self, content: bytes, metadata: CreateFileRequest) -> bool:
        """Run every validation step for an upload.

        Returns:
            True when all checks pass.

        Raises:
            Exception: If the source is unknown, or the extension, type or
                content is invalid.
        """
        # 1. Get config
        config = self.get_config(metadata.source)
        if config is None:
            # Fail with a clear message instead of an AttributeError further down.
            raise Exception("Invalid file source")

        # 2. Validate file extension
        if not self.validate_file_extension(metadata.filename, config):
            raise Exception("Invalid file extension")

        # 3. Validate file type
        if not self.validate_file_type(metadata, config):
            raise Exception("Invalid file type")

        # 4. Validate file content
        if not self.validate_file_content(content, metadata, config):
            raise Exception("Invalid file content")

        return True

    # CRUD
    # CREATE
    def create_file(self, content: bytes, metadata: CreateFileRequest) -> File:
        """Create a pending `File` row and write `content` under TEMP_DIR.

        Args:
            content: Raw bytes of the uploaded file.
            metadata: Upload request; its filename, type, source and service
                are copied onto the row.

        Returns:
            The newly added `File` object.

        Raises:
            Exception: If the filename has no usable final path component.
            OSError: If the directory or file cannot be written.
        """
        # Strip directory components so a name like "../../x.csv" cannot
        # escape TEMP_DIR.
        safe_name = os.path.basename(metadata.filename)
        if not safe_name:
            raise Exception("Invalid filename")

        # The temp directory may not exist yet on a fresh deployment.
        os.makedirs(TEMP_DIR, exist_ok=True)

        with db_transaction(self.db):
            file = File(
                id=str(uuid()),
                filename=metadata.filename,
                filepath=TEMP_DIR + safe_name,  # TODO config variable
                type=metadata.type,
                source=metadata.source,
                filesize_kb=round(len(content) / 1024, 2),
                status='pending',
                service=metadata.service
            )
            self.db.add(file)
            # Written inside the transaction block so a failed write propagates
            # through db_transaction (presumably rolling back - confirm).
            with open(file.filepath, 'wb') as f:
                f.write(content)
        return file

    # GET
    def get_file(self, file_id: str) -> File:
        """Return the `File` with the given id.

        Raises:
            Exception: If no such file exists.
        """
        file = self.db.query(File).filter(File.id == file_id).first()
        if not file:
            raise Exception(f"File with id {file_id} not found")
        return file

    def get_files(self, status: Optional[str] = None) -> List[File]:
        """Return all files, restricted to `status` when one is given."""
        if status:
            return self.db.query(File).filter(File.status == status).all()
        return self.db.query(File).all()

    # DELETE
    def get_staged_products(self, file_id: str) -> List[StagedFileProduct]:
        """Return every staged product row that references `file_id`."""
        return self.db.query(StagedFileProduct).filter(StagedFileProduct.file_id == file_id).all()

    def delete_file(self, file_id: str) -> File:
        """Mark a file as deleted and delete its staged products.

        The `File` row is kept with status 'deleted'; the file on disk is not
        touched by this method.

        Returns:
            The updated `File` object.

        Raises:
            Exception: If no file with `file_id` exists (raised by `get_file`).
        """
        file = self.get_file(file_id)
        staged_products = self.get_staged_products(file_id)
        with db_transaction(self.db):
            file.status = 'deleted'
            for staged_product in staged_products:
                self.db.delete(staged_product)
        return file