giga_tcg/services/file.py
2025-02-05 21:51:22 -05:00

128 lines
4.9 KiB
Python

from sqlalchemy.orm import Session
from db.utils import db_transaction
from db.models import File, StagedFileProduct
from schemas.file import CreateFileRequest
import os
from uuid import uuid4 as uuid
import logging
import csv
from io import StringIO
from typing import Optional, List
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Column names that must all be present in the header row of a ManaBox CSV export:
# Name,Set code,Set name,Collector number,Foil,Rarity,Quantity,ManaBox ID,Scryfall ID,Purchase price,Misprint,Altered,Condition,Language,Purchase price currency
MANABOX_REQUIRED_FILE_HEADERS = [
    'Name',
    'Set code',
    'Set name',
    'Collector number',
    'Foil',
    'Rarity',
    'Quantity',
    'ManaBox ID',
    'Scryfall ID',
    'Purchase price',
    'Misprint',
    'Altered',
    'Condition',
    'Language',
    'Purchase price currency',
]

# Filename suffixes accepted for ManaBox uploads.
MANABOX_ALLOWED_FILE_EXTENSIONS = ['.csv']

# Values of the request's `type` field accepted for ManaBox uploads.
MANABOX_ALLOWED_FILE_TYPES = [
    'scan_export_common',
    'scan_export_rare',
]

# Validation rules for the "manabox" source, bundled into one lookup.
MANABOX_CONFIG = {
    "required_headers": MANABOX_REQUIRED_FILE_HEADERS,
    "allowed_extensions": MANABOX_ALLOWED_FILE_EXTENSIONS,
    "allowed_types": MANABOX_ALLOWED_FILE_TYPES,
}

# Registry mapping a source name to its validation config.
SOURCES = {"manabox": MANABOX_CONFIG}

# Directory (relative to the process working directory at import time) where
# uploaded files are written.
TEMP_DIR = f"{os.getcwd()}/temp/"
class FileService:
    """Validate uploaded files, store them on disk and track them in the database.

    Validation rules are looked up per source in the module-level ``SOURCES``
    registry. Files are written under ``TEMP_DIR`` and recorded as ``File``
    rows; deletion is a soft delete that flips the row's status.
    """

    def __init__(self, db: Session):
        """Keep the database session used by every query in this service."""
        self.db = db

    # CONFIG
    def get_config(self, source: str) -> Optional[dict]:
        """Return the validation config registered for ``source``.

        Returns ``None`` when the source is not present in ``SOURCES``.
        """
        return SOURCES.get(source)

    # VALIDATION
    def validate_file_extension(self, filename: str, config: dict) -> bool:
        """Return True if ``filename`` ends with one of the config's allowed extensions.

        The comparison is case-sensitive. A config without an
        ``allowed_extensions`` entry allows nothing.
        """
        return filename.endswith(tuple(config.get("allowed_extensions", ())))

    def validate_file_type(self, metadata: CreateFileRequest, config: dict) -> bool:
        """Return True if ``metadata.type`` is one of the config's allowed types.

        A config without an ``allowed_types`` entry allows nothing.
        """
        return metadata.type in config.get("allowed_types", ())

    def validate_csv(self, content: bytes, required_headers: Optional[List[str]] = None) -> bool:
        """Check that ``content`` is UTF-8 CSV whose first row has the required headers.

        Args:
            content: Raw bytes of the uploaded file.
            required_headers: Column names that must all appear in the first
                row. When empty or ``None`` only decoding/parsing is checked.

        Returns:
            False if the bytes are not valid UTF-8, the CSV parser rejects the
            first row, or a required header is missing (including when the
            content has no rows at all); True otherwise.
        """
        try:
            csv_text = content.decode('utf-8')
            # Only the first row is needed; later rows are not inspected.
            headers = next(csv.reader(StringIO(csv_text)), None)
        except (UnicodeDecodeError, csv.Error):
            return False
        if not required_headers:
            return True
        if headers is None:
            # Empty content: there is no header row to satisfy the requirement.
            return False
        return all(header in headers for header in required_headers)

    def validate_file_content(self, content: bytes, metadata: CreateFileRequest, config: dict) -> bool:
        """Validate ``content`` according to the extension of ``metadata.filename``.

        Only ``csv`` is supported; any other extension returns False.
        """
        extension = metadata.filename.split('.')[-1]
        if extension == 'csv':
            return self.validate_csv(content, config.get("required_headers"))
        return False

    def validate_file(self, content: bytes, metadata: CreateFileRequest) -> bool:
        """Run all validation steps for an upload.

        Returns:
            True when every check passes.

        Raises:
            Exception: with a message naming the failed check (unknown source,
                extension, type or content).
        """
        # 1. Get config
        config = self.get_config(metadata.source)
        if config is None:
            # Without this guard the checks below would fail with an
            # AttributeError on ``None`` instead of a meaningful message.
            raise Exception("Invalid source")
        # 2. Validate file extension
        if not self.validate_file_extension(metadata.filename, config):
            raise Exception("Invalid file extension")
        # 3. Validate file type
        if not self.validate_file_type(metadata, config):
            raise Exception("Invalid file type")
        # 4. Validate file content
        if not self.validate_file_content(content, metadata, config):
            raise Exception("Invalid file content")
        return True

    # CRUD
    # CREATE
    def create_file(self, content: bytes, metadata: CreateFileRequest) -> File:
        """Create a ``File`` row in status ``pending`` and write ``content`` to disk.

        The file is written under ``TEMP_DIR`` using only the final path
        component of ``metadata.filename``. Uploads sharing a filename
        therefore share a path, and the later one overwrites the earlier.

        Raises:
            Exception: if ``metadata.filename`` has no usable final component.
            OSError: if the directory cannot be created or the file cannot be
                written.
        """
        # basename() drops directory components so a crafted name such as
        # "../../x.csv" cannot place the file outside TEMP_DIR.
        safe_name = os.path.basename(metadata.filename)
        if not safe_name:
            raise Exception("Invalid filename")
        os.makedirs(TEMP_DIR, exist_ok=True)
        filepath = TEMP_DIR + safe_name  # TODO config variable
        # NOTE(review): db_transaction presumably commits on success and rolls
        # back on error - confirm in db.utils. The write happens inside it so a
        # failed write also aborts the row.
        with db_transaction(self.db):
            file = File(
                id=str(uuid()),
                filename=metadata.filename,
                filepath=filepath,
                type=metadata.type,
                source=metadata.source,
                filesize_kb=round(len(content) / 1024, 2),
                status='pending',
                service=metadata.service,
            )
            self.db.add(file)
            with open(filepath, 'wb') as f:
                f.write(content)
            return file

    # GET
    def get_file(self, file_id: str) -> File:
        """Return the ``File`` row with the given id.

        Raises:
            Exception: if no row has that id.
        """
        file = self.db.query(File).filter(File.id == file_id).first()
        if not file:
            raise Exception(f"File with id {file_id} not found")
        return file

    def get_files(self, status: Optional[str] = None) -> List[File]:
        """Return all ``File`` rows, restricted to ``status`` when one is given."""
        if status:
            return self.db.query(File).filter(File.status == status).all()
        return self.db.query(File).all()

    # DELETE
    def get_staged_products(self, file_id: str) -> List[StagedFileProduct]:
        """Return every ``StagedFileProduct`` row that references ``file_id``."""
        return self.db.query(StagedFileProduct).filter(StagedFileProduct.file_id == file_id).all()

    def delete_file(self, file_id: str) -> File:
        """Soft-delete a file: mark it ``deleted`` and remove its staged products.

        The ``File`` row is kept; this method does not touch the file on disk.

        Raises:
            Exception: if no file has that id (raised by ``get_file``).
        """
        file = self.get_file(file_id)
        staged_products = self.get_staged_products(file_id)
        with db_transaction(self.db):
            file.status = 'deleted'
            for staged_product in staged_products:
                self.db.delete(staged_product)
            return file