156 lines
5.8 KiB
Python
156 lines
5.8 KiB
Python
from sqlalchemy.orm import Session
|
|
from typing import Optional, List, Dict, Any
|
|
from uuid import uuid4
|
|
import csv
|
|
import logging
|
|
import os
|
|
from io import StringIO
|
|
|
|
from app.db.utils import db_transaction
|
|
from app.db.models import File, StagedFileProduct
|
|
from app.schemas.file import CreateFileRequest
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class FileConfig:
    """Configuration constants for file processing.

    Holds the directory uploaded files are written to and, per upload
    source, the rules an upload must satisfy (required CSV headers,
    allowed filename extensions and allowed file types).
    """

    # Directory uploaded files are written to: ``<cwd>/app/temp``.
    # Resolved once, when the class body runs, from the working directory
    # at that moment. Built from separate segments so os.path.join picks
    # the platform's separator instead of a hard-coded "/".
    TEMP_DIR = os.path.join(os.getcwd(), 'app', 'temp')

    # Column names that must all appear in the header row of a ManaBox CSV.
    MANABOX_HEADERS = [
        'Name', 'Set code', 'Set name', 'Collector number', 'Foil',
        'Rarity', 'Quantity', 'ManaBox ID', 'Scryfall ID', 'Purchase price',
        'Misprint', 'Altered', 'Condition', 'Language', 'Purchase price currency',
    ]

    # Validation rules keyed by upload source name.
    SOURCES = {
        "manabox": {
            "required_headers": MANABOX_HEADERS,
            "allowed_extensions": ['.csv'],
            "allowed_types": ['scan_export_common', 'scan_export_rare'],
        },
    }
|
|
|
|
class FileValidationError(Exception):
    """Signal that a file failed validation or could not be found or read.

    Carries only the message passed to it; it adds no attributes or
    behavior beyond ``Exception``.
    """
|
|
|
|
class FileService:
    """Validate, store and retrieve uploaded files and their database records.

    Database access goes through the SQLAlchemy session supplied at
    construction; file contents are written to and read from disk under
    ``FileConfig.TEMP_DIR``.
    """

    def __init__(self, db: Session):
        """Bind the service to *db*, the session used for every query."""
        self.db = db

    def get_config(self, source: str) -> Dict[str, Any]:
        """Return the ``FileConfig.SOURCES`` entry for *source*.

        Raises:
            FileValidationError: If *source* has no (non-empty) entry.
        """
        config = FileConfig.SOURCES.get(source)
        if not config:
            raise FileValidationError(f"Unsupported source: {source}")
        return config

    def validate_file_extension(self, filename: str, config: Dict[str, Any]) -> bool:
        """Return True if *filename* ends with one of the allowed extensions.

        The comparison ignores case so that it agrees with
        ``validate_file_content``, which lower-cases the extension before
        dispatching on it (``cards.CSV`` is a CSV for both checks).
        """
        allowed = tuple(ext.lower() for ext in config["allowed_extensions"])
        # str.endswith accepts a tuple; an empty tuple never matches.
        return filename.lower().endswith(allowed)

    def validate_file_type(self, metadata: CreateFileRequest, config: Dict[str, Any]) -> bool:
        """Return True if ``metadata.type`` is one of the allowed types."""
        return metadata.type in config["allowed_types"]

    def validate_csv(self, content: bytes, required_headers: Optional[List[str]] = None) -> bool:
        """Check that *content* is UTF-8 text whose first row has the headers.

        Args:
            content: Raw bytes of the uploaded file.
            required_headers: Column names that must all appear in the first
                CSV row. When empty or ``None`` only the UTF-8 decoding is
                checked.

        Returns:
            True when the checks pass; False on undecodable bytes, a CSV
            parse error, an empty first row or a missing header.
        """
        try:
            # NOTE(review): plain 'utf-8' keeps a leading byte-order mark, so
            # a BOM-prefixed export has "\ufeffName" as its first header and
            # fails the header check. Confirm whether such files must pass.
            csv_text = content.decode('utf-8')
            if not required_headers:
                return True
            headers = next(csv.reader(StringIO(csv_text)), None)
        except (UnicodeDecodeError, csv.Error) as e:
            logger.error("CSV validation error: %s", e)
            return False

        if not headers:
            return False
        present = set(headers)
        return all(header in present for header in required_headers)

    def validate_file_content(self, content: bytes, metadata: CreateFileRequest, config: Dict[str, Any]) -> bool:
        """Validate *content* according to the extension of ``metadata.filename``.

        Only ``.csv`` (any letter case) is supported; every other extension
        is reported as invalid.
        """
        extension = os.path.splitext(metadata.filename)[1].lower()
        if extension == '.csv':
            return self.validate_csv(content, config.get("required_headers"))
        return False

    def validate_file(self, content: bytes, metadata: CreateFileRequest) -> bool:
        """Run the extension, type and content checks for an upload.

        Returns:
            True when every check passes.

        Raises:
            FileValidationError: If ``metadata.source`` is unsupported or any
                check fails; the message names the failing check.
        """
        config = self.get_config(metadata.source)

        if not self.validate_file_extension(metadata.filename, config):
            raise FileValidationError("Invalid file extension")

        if not self.validate_file_type(metadata, config):
            raise FileValidationError("Invalid file type")

        if not self.validate_file_content(content, metadata, config):
            raise FileValidationError("Invalid file content or headers")

        return True

    def create_file(self, content: bytes, metadata: CreateFileRequest) -> File:
        """Create a ``File`` record and write *content* to disk.

        The record keeps ``metadata.filename`` as given, but the on-disk path
        is built from the filename's final path component only, so a name
        such as ``../../x.csv`` or an absolute path cannot place the file
        outside ``FileConfig.TEMP_DIR``.

        Raises:
            FileValidationError: If the filename has no usable final
                component (empty, ``.`` or ``..``).
        """
        safe_name = os.path.basename(metadata.filename)
        if safe_name in ('', '.', '..'):
            raise FileValidationError("Invalid file name")
        filepath = os.path.join(FileConfig.TEMP_DIR, safe_name)

        # The disk write happens inside the transaction block, so an I/O
        # error propagates through db_transaction.
        # NOTE(review): presumably db_transaction rolls back on exception
        # and commits on success -- confirm in app.db.utils.
        with db_transaction(self.db):
            file = File(
                id=str(uuid4()),
                filename=metadata.filename,
                filepath=filepath,
                type=metadata.type,
                source=metadata.source,
                filesize_kb=round(len(content) / 1024, 2),
                status='pending',
                service=metadata.service,
            )
            self.db.add(file)

            os.makedirs(FileConfig.TEMP_DIR, exist_ok=True)
            # NOTE(review): two uploads with the same filename share one path,
            # so the later write replaces the earlier file's content.
            with open(file.filepath, 'wb') as f:
                f.write(content)

        return file

    def get_file(self, file_id: str) -> File:
        """Return the ``File`` whose id is *file_id*.

        Raises:
            FileValidationError: If no such record exists.
        """
        file = self.db.query(File).filter(File.id == file_id).first()
        if not file:
            raise FileValidationError(f"File with id {file_id} not found")
        return file

    def get_files(self, status: Optional[str] = None) -> List[File]:
        """Return all ``File`` records, filtered by *status* when it is truthy."""
        query = self.db.query(File)
        if status:
            query = query.filter(File.status == status)
        return query.all()

    def get_staged_products(self, file_id: str) -> List[StagedFileProduct]:
        """Return the ``StagedFileProduct`` rows whose ``file_id`` matches."""
        return self.db.query(StagedFileProduct).filter(
            StagedFileProduct.file_id == file_id
        ).all()

    def delete_file(self, file_id: str) -> File:
        """Mark a file as deleted and delete its staged products.

        The ``File`` row itself is kept with ``status`` set to ``'deleted'``;
        the file on disk is not touched.

        Raises:
            FileValidationError: If no file with *file_id* exists.
        """
        file = self.get_file(file_id)
        staged_products = self.get_staged_products(file_id)

        with db_transaction(self.db):
            file.status = 'deleted'
            for staged_product in staged_products:
                self.db.delete(staged_product)

        return file

    def get_file_content(self, file_id: str) -> bytes:
        """Return the bytes stored at the file's recorded ``filepath``.

        Raises:
            FileValidationError: If the record does not exist or the file
                cannot be read; in the latter case the underlying
                ``OSError`` is attached as the cause.
        """
        file = self.get_file(file_id)
        try:
            with open(file.filepath, 'rb') as f:
                return f.read()
        except OSError as e:
            logger.error("Error reading file %s: %s", file_id, e)
            raise FileValidationError(f"Could not read file content for {file_id}") from e