128 lines
4.8 KiB
Python
128 lines
4.8 KiB
Python
from sqlalchemy.orm import Session
|
|
from db.utils import db_transaction
|
|
from db.models import File, StagedFileProduct
|
|
from schemas.file import CreateFileRequest
|
|
import os
|
|
from uuid import uuid4 as uuid
|
|
import logging
|
|
import csv
|
|
from io import StringIO
|
|
from typing import Optional, List
|
|
|
|
# Module-level logger for this service module.
logger = logging.getLogger(__name__)

# Column headers that must be present in an uploaded ManaBox CSV:
# Name,Set code,Set name,Collector number,Foil,Rarity,Quantity,ManaBox ID,Scryfall ID,Purchase price,Misprint,Altered,Condition,Language,Purchase price currency
MANABOX_REQUIRED_FILE_HEADERS = [
    'Name',
    'Set code',
    'Set name',
    'Collector number',
    'Foil',
    'Rarity',
    'Quantity',
    'ManaBox ID',
    'Scryfall ID',
    'Purchase price',
    'Misprint',
    'Altered',
    'Condition',
    'Language',
    'Purchase price currency',
]

# File-name suffixes accepted for ManaBox uploads.
MANABOX_ALLOWED_FILE_EXTENSIONS = ['.csv']

# Values of the request's ``type`` field accepted for ManaBox uploads.
MANABOX_ALLOWED_FILE_TYPES = ['scan_export']

# Validation settings for the "manabox" source.
MANABOX_CONFIG = dict(
    required_headers=MANABOX_REQUIRED_FILE_HEADERS,
    allowed_extensions=MANABOX_ALLOWED_FILE_EXTENSIONS,
    allowed_types=MANABOX_ALLOWED_FILE_TYPES,
)

# Registry of supported upload sources, keyed by the request's ``source`` value.
SOURCES = {"manabox": MANABOX_CONFIG}

# Directory (with a trailing slash) that uploaded files are written into.
TEMP_DIR = f"{os.getcwd()}/temp/"
class FileService:
    """Validate, store, look up and soft-delete uploaded files.

    Database access goes through the SQLAlchemy ``Session`` passed to the
    constructor; mutations are wrapped in ``db_transaction``.
    """

    def __init__(self, db: Session):
        self.db = db

    # CONFIG

    def get_config(self, source: str) -> Optional[dict]:
        """Return the validation config for ``source``, or None if it is not in ``SOURCES``."""
        return SOURCES.get(source)

    # VALIDATION

    def validate_file_extension(self, filename: str, config: dict) -> bool:
        """Return True if ``filename`` ends with one of ``config["allowed_extensions"]``.

        Matching is case-insensitive, so ``cards.CSV`` is accepted like ``cards.csv``.
        """
        allowed = tuple(ext.lower() for ext in (config.get("allowed_extensions") or ()))
        return filename.lower().endswith(allowed)

    def validate_file_type(self, metadata: CreateFileRequest, config: dict) -> bool:
        """Return True if ``metadata.type`` is listed in ``config["allowed_types"]``."""
        return metadata.type in config.get("allowed_types")

    def validate_csv(self, content: bytes, required_headers: Optional[List[str]] = None) -> bool:
        """Return True if ``content`` is UTF-8 CSV whose first row has every required header.

        Args:
            content: Raw file bytes. A leading UTF-8 byte-order mark is ignored.
            required_headers: Column names that must all appear in the first row.
                When falsy, only decoding and parsing of the first row is checked.

        Returns:
            False if the bytes are not valid UTF-8, the first row is malformed
            CSV, or a required header is missing (including an empty file).
        """
        try:
            # 'utf-8-sig' drops a BOM if present and otherwise decodes exactly like
            # 'utf-8'; without it the first header would read '\ufeffName'.
            csv_text = content.decode('utf-8-sig')
            headers = next(csv.reader(StringIO(csv_text)), None)
        except (UnicodeDecodeError, csv.Error):
            return False

        if not required_headers:
            return True
        # An empty file has no header row, so it cannot contain any required header.
        if headers is None:
            return False
        return set(required_headers).issubset(headers)

    def validate_file_content(self, content: bytes, metadata: CreateFileRequest, config: dict) -> bool:
        """Check ``content`` with the parser matching ``metadata.filename``'s extension.

        Only CSV is supported; the extension is compared case-insensitively,
        like ``validate_file_extension``. Any other extension is rejected.
        """
        if metadata.filename.lower().endswith('.csv'):
            return self.validate_csv(content, config.get("required_headers"))
        return False

    def validate_file(self, content: bytes, metadata: CreateFileRequest) -> bool:
        """Run every upload check, raising on the first failure.

        Returns:
            True when all checks pass.

        Raises:
            Exception: if the source is unknown, or the extension, type or
                content is not accepted for that source.
        """
        # 1. Get config. Unknown sources have none; fail with a clear message
        #    rather than an AttributeError on ``None.get``.
        config = self.get_config(metadata.source)
        if config is None:
            raise Exception("Invalid file source")
        # 2. Validate file extension
        if not self.validate_file_extension(metadata.filename, config):
            raise Exception("Invalid file extension")
        # 3. Validate file type
        if not self.validate_file_type(metadata, config):
            raise Exception("Invalid file type")
        # 4. Validate file content
        if not self.validate_file_content(content, metadata, config):
            raise Exception("Invalid file content")
        return True

    # CRUD

    # CREATE

    def create_file(self, content: bytes, metadata: CreateFileRequest) -> File:
        """Add a 'pending' File row and write ``content`` to disk under TEMP_DIR.

        The row is added and the bytes written inside the same
        ``db_transaction`` block (presumably rolled back if the write raises;
        confirm against ``db_transaction``).

        Raises:
            Exception: if ``metadata.filename`` has no usable base name.
        """
        # The filename comes from the client: keep only its last path component
        # so names like '../../x' or 'a\\..\\b' cannot escape TEMP_DIR.
        safe_name = os.path.basename(metadata.filename.replace('\\', '/'))
        if safe_name in ('', '.', '..'):
            raise Exception("Invalid filename")
        # Ensure the target directory exists before opening a file inside it.
        os.makedirs(TEMP_DIR, exist_ok=True)

        with db_transaction(self.db):
            file = File(
                id=str(uuid()),
                filename=metadata.filename,
                filepath=TEMP_DIR + safe_name,  # TODO config variable
                type=metadata.type,
                source=metadata.source,
                filesize_kb=round(len(content) / 1024, 2),
                status='pending',
                service=metadata.service,
            )
            self.db.add(file)
            with open(file.filepath, 'wb') as f:
                f.write(content)
        return file

    # GET

    def get_file(self, file_id: str) -> File:
        """Return the File with ``file_id``.

        Raises:
            Exception: if no such file exists.
        """
        file = self.db.query(File).filter(File.id == file_id).first()
        if not file:
            raise Exception(f"File with id {file_id} not found")
        return file

    def get_files(self, status: Optional[str] = None) -> List[File]:
        """Return all files, or only those whose status equals ``status`` when it is truthy."""
        query = self.db.query(File)
        if status:
            query = query.filter(File.status == status)
        return query.all()

    def get_staged_products(self, file_id: str) -> List[StagedFileProduct]:
        """Return every StagedFileProduct row linked to ``file_id``."""
        return self.db.query(StagedFileProduct).filter(StagedFileProduct.file_id == file_id).all()

    # DELETE

    def delete_file(self, file_id: str) -> File:
        """Soft-delete a file: set its status to 'deleted' and delete its staged products.

        The stored file on disk is not removed.

        Raises:
            Exception: if no file with ``file_id`` exists (raised by ``get_file``).
        """
        file = self.get_file(file_id)
        with db_transaction(self.db):
            file.status = 'deleted'
            for staged_product in self.get_staged_products(file_id):
                self.db.delete(staged_product)
        return file