2025-02-07 20:54:55 -05:00

156 lines
5.8 KiB
Python

from sqlalchemy.orm import Session
from typing import Optional, List, Dict, Any
from uuid import uuid4
import csv
import logging
import os
from io import StringIO
from app.db.utils import db_transaction
from app.db.models import File, StagedFileProduct
from app.schemas.file import CreateFileRequest
logger = logging.getLogger(__name__)
class FileConfig:
    """Configuration constants for file processing.

    Everything here is a class-level constant read directly from the class
    (for example ``FileConfig.TEMP_DIR``); the class is never instantiated
    in this module.
    """

    # Directory uploaded files are written to. It is resolved against the
    # process's working directory at the moment this class body executes,
    # so it depends on where the application was started from.
    TEMP_DIR = os.path.join(os.getcwd(), 'app', 'temp')

    # Column names that must all appear in the header row of a file from the
    # "manabox" source (used as ``required_headers`` below).
    MANABOX_HEADERS = [
        'Name',
        'Set code',
        'Set name',
        'Collector number',
        'Foil',
        'Rarity',
        'Quantity',
        'ManaBox ID',
        'Scryfall ID',
        'Purchase price',
        'Misprint',
        'Altered',
        'Condition',
        'Language',
        'Purchase price currency',
    ]

    # Validation rules keyed by source name. Each entry provides:
    #   required_headers   - header names the CSV must contain
    #   allowed_extensions - accepted filename suffixes (with leading dot)
    #   allowed_types      - accepted values for the request's ``type`` field
    SOURCES = {
        "manabox": {
            "required_headers": MANABOX_HEADERS,
            "allowed_extensions": ['.csv'],
            "allowed_types": ['scan_export_common', 'scan_export_rare'],
        },
    }
class FileValidationError(Exception):
    """Raised when a file, or a request about a file, fails validation.

    Within this module it signals an unsupported source, a rejected
    extension/type/content check, a file id that does not exist, or stored
    file content that cannot be read.
    """
class FileService:
    """Validate, store and look up uploaded files and their staged products.

    Database access goes through the SQLAlchemy session passed to the
    constructor; file contents are written under ``FileConfig.TEMP_DIR``.
    """

    def __init__(self, db: Session):
        """Bind the service to the session used for all queries and writes."""
        self.db = db

    def get_config(self, source: str) -> Dict[str, Any]:
        """Return the validation config registered for ``source``.

        Raises:
            FileValidationError: if ``FileConfig.SOURCES`` has no (non-empty)
                entry for ``source``.
        """
        config = FileConfig.SOURCES.get(source)
        if not config:
            raise FileValidationError(f"Unsupported source: {source}")
        return config

    def validate_file_extension(self, filename: str, config: Dict[str, Any]) -> bool:
        """Return True if ``filename`` ends with an allowed extension.

        The comparison ignores case so ``cards.CSV`` is accepted just like
        ``cards.csv``; ``validate_file_content`` lowercases the extension as
        well, and the two checks must agree.
        """
        allowed = tuple(ext.lower() for ext in config["allowed_extensions"])
        # str.endswith accepts a tuple; an empty tuple never matches.
        return filename.lower().endswith(allowed)

    def validate_file_type(self, metadata: CreateFileRequest, config: Dict[str, Any]) -> bool:
        """Return True if ``metadata.type`` is one of the allowed types."""
        return metadata.type in config["allowed_types"]

    def validate_csv(self, content: bytes, required_headers: Optional[List[str]] = None) -> bool:
        """Check that ``content`` is UTF-8 CSV containing the required headers.

        Args:
            content: Raw bytes of the uploaded file.
            required_headers: Names that must all be present in the first
                row. When empty or None only the decoding is checked.

        Returns:
            True when the content decodes and every required header is
            present; False otherwise (decode/CSV errors are logged).
        """
        try:
            # utf-8-sig drops a leading byte-order mark if there is one and
            # otherwise behaves like utf-8. Without it the BOM would stick to
            # the first header name and the header check below would fail.
            csv_text = content.decode('utf-8-sig')
            if required_headers:
                headers = next(csv.reader(StringIO(csv_text)), None)
                if not headers:
                    return False
                present = set(headers)
                if not all(header in present for header in required_headers):
                    return False
            return True
        except (UnicodeDecodeError, csv.Error) as e:
            logger.error("CSV validation error: %s", e)
            return False

    def validate_file_content(self, content: bytes, metadata: CreateFileRequest, config: Dict[str, Any]) -> bool:
        """Validate ``content`` according to the extension of the filename.

        Only ``.csv`` is handled; any other extension is rejected.
        """
        extension = os.path.splitext(metadata.filename)[1].lower()
        if extension == '.csv':
            return self.validate_csv(content, config.get("required_headers"))
        return False

    def validate_file(self, content: bytes, metadata: CreateFileRequest) -> bool:
        """Run the extension, type and content checks, in that order.

        Returns:
            True when every check passes.

        Raises:
            FileValidationError: for an unsupported source or on the first
                check that fails.
        """
        config = self.get_config(metadata.source)
        if not self.validate_file_extension(metadata.filename, config):
            raise FileValidationError("Invalid file extension")
        if not self.validate_file_type(metadata, config):
            raise FileValidationError("Invalid file type")
        if not self.validate_file_content(content, metadata, config):
            raise FileValidationError("Invalid file content or headers")
        return True

    def create_file(self, content: bytes, metadata: CreateFileRequest) -> File:
        """Create a ``File`` record and write ``content`` to disk.

        The file is stored in ``FileConfig.TEMP_DIR`` under the final path
        component of ``metadata.filename``.

        Raises:
            FileValidationError: if the filename has no usable final
                component (empty, ``.`` or ``..``).
        """
        # The filename comes from the request, so treat it as untrusted:
        # keeping only the last path component prevents "../x" or an
        # absolute path from writing outside TEMP_DIR.
        safe_name = os.path.basename(metadata.filename)
        if safe_name in ('', '.', '..'):
            raise FileValidationError("Invalid filename")

        # NOTE(review): db_transaction is defined elsewhere; presumably it
        # commits on success and rolls back on error - confirm.
        with db_transaction(self.db):
            file = File(
                id=str(uuid4()),
                filename=metadata.filename,
                filepath=os.path.join(FileConfig.TEMP_DIR, safe_name),
                type=metadata.type,
                source=metadata.source,
                filesize_kb=round(len(content) / 1024, 2),
                status='pending',
                service=metadata.service
            )
            self.db.add(file)

            os.makedirs(FileConfig.TEMP_DIR, exist_ok=True)
            with open(file.filepath, 'wb') as f:
                f.write(content)

            return file

    def get_file(self, file_id: str) -> File:
        """Return the ``File`` with the given id.

        Raises:
            FileValidationError: if no such file exists.
        """
        file = self.db.query(File).filter(File.id == file_id).first()
        if not file:
            raise FileValidationError(f"File with id {file_id} not found")
        return file

    def get_files(self, status: Optional[str] = None) -> List[File]:
        """Return all files, filtered by ``status`` when one is given."""
        query = self.db.query(File)
        if status:
            query = query.filter(File.status == status)
        return query.all()

    def get_staged_products(self, file_id: str) -> List[StagedFileProduct]:
        """Return the staged products whose ``file_id`` matches."""
        return self.db.query(StagedFileProduct).filter(
            StagedFileProduct.file_id == file_id
        ).all()

    def delete_file(self, file_id: str) -> File:
        """Mark a file as deleted and delete its staged products.

        The file row itself is kept; only its ``status`` changes to
        ``'deleted'``.

        Raises:
            FileValidationError: if no file with ``file_id`` exists.
        """
        file = self.get_file(file_id)
        staged_products = self.get_staged_products(file_id)
        with db_transaction(self.db):
            file.status = 'deleted'
            for staged_product in staged_products:
                self.db.delete(staged_product)
            return file

    def get_file_content(self, file_id: str) -> bytes:
        """Return the bytes stored at the file's ``filepath``.

        Raises:
            FileValidationError: if the file id is unknown or the file on
                disk cannot be read (the original error is chained).
        """
        file = self.get_file(file_id)
        try:
            with open(file.filepath, 'rb') as f:
                return f.read()
        except OSError as e:  # IOError is an alias of OSError
            logger.error("Error reading file %s: %s", file_id, e)
            raise FileValidationError(f"Could not read file content for {file_id}") from e