diff --git a/.gitignore b/.gitignore
index 8764a20..68bd868 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.db
 __pycache__
 .venv
-*.sqlite3
\ No newline at end of file
+*.sqlite3
+*.log
\ No newline at end of file
diff --git a/scraper/app.py b/scraper/app.py
index 122126e..7ae6679 100644
--- a/scraper/app.py
+++ b/scraper/app.py
@@ -1,103 +1,247 @@
-import threading
-import time
-from datetime import datetime
+from datetime import datetime, timedelta
 import requests
-from models import Submission
-import logging
+from models import Post
+import praw
+from zoneinfo import ZoneInfo
+from exceptions import InvalidMethodError, InvalidDataTypeError, APIRequestError
+from app_log import LoggingManager
+from threads import Scheduler, ThreadManager
 
-# logging
-logging.basicConfig(level=logging.INFO)
 
-class Application:
-    def __init__(self, reddit_monitor, webhook_notifier, api_url):
-        self.reddit_monitor = reddit_monitor
-        self.webhook_notifier = webhook_notifier
+class ApiRequestHandler:
+    def __init__(self, api_url: str):
         self.api_url = api_url
-
-    def send_api_request(self, method, url, data=None, params=None):
-        response = requests.request(method, url, data=data, params=params)
+        self.log_manager = LoggingManager("scraper.log")
+
+    def send_api_request(
+        self, method: str, api_url: str, data=None, params=None
+    ) -> dict:
+        if method not in ["GET", "POST", "PUT", "DELETE"]:
+            raise InvalidMethodError(f"Invalid method: {method}")
+        if data is not None and not isinstance(data, dict):
+            raise InvalidDataTypeError(f"Invalid data type: {type(data)} expected dict")
+        if params is not None and not isinstance(params, dict):
+            raise InvalidDataTypeError(
+                f"Invalid data type: {type(params)} expected dict"
+            )
+        response = requests.request(method, api_url, data=data, params=params)
+        success_codes = [200, 201, 204]
+        if response.status_code not in success_codes:
+            self.log_manager.error(
+                f"API request failed: {response.status_code} - {response.text}"
+            )
+            raise APIRequestError(response.status_code, response.text)
         return response.json()
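
# --- Reviewer sketch (not part of the diff) ---
# Minimal usage of the new ApiRequestHandler, assuming scraper/ is on the
# import path with its dependencies installed and the API URL from config.py
# is reachable. Invalid methods/payloads fail fast; any non-2xx response
# raises APIRequestError instead of silently returning bad JSON.
from app import ApiRequestHandler
from exceptions import APIRequestError

api = ApiRequestHandler("http://server:8000/api/")
try:
    posts = api.send_api_request("GET", f"{api.api_url}posts/", params={"last_7_days": 1})
except APIRequestError as e:
    print(f"API rejected the request: {e.status_code} - {e.message}")
# --- end sketch ---
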
-
-    def get_submission_by_reddit_id(self, reddit_id):
-        logging.info(f"Getting submission by reddit_id: {reddit_id}")
-        logging.info(f"{self.api_url}submissions/?reddit_id={reddit_id}")
-        response = self.send_api_request("GET", f"{self.api_url}submissions/?reddit_id={reddit_id}")
-        logging.info(response)
+
+
+class PostManager:
+    def __init__(self, api_request_handler: ApiRequestHandler):
+        self.api_request_handler = api_request_handler
+        self.log_manager = LoggingManager("scraper.log")
+
+    def get_post_by_reddit_id(self, reddit_id: str) -> dict:
+        self.log_manager.log(f"Getting post by reddit id: {reddit_id}")
+        response = self.api_request_handler.send_api_request(
+            "GET", f"{self.api_request_handler.api_url}posts/?reddit_id={reddit_id}"
+        )
         return response
-
-    def submission_exists(self, reddit_id):
-        response = self.get_submission_by_reddit_id(reddit_id)
+
+    def post_exists(self, reddit_id: str) -> bool:
+        self.log_manager.log(f"Checking if post exists: {reddit_id}")
+        response = self.get_post_by_reddit_id(reddit_id)
         if len(response) == 0:
-            logging.info(f"Submission {reddit_id} does not exist")
             return False
         return True
-
-    def update_submission_analytics(self, submission):
-        submission_id = self.get_submission_by_reddit_id(submission.reddit_id)
-        logging.info(submission_id)
-        submission_id = submission_id[0]["id"]
+
+    def insert_post(self, post) -> dict:
+        self.log_manager.log(f"Inserting post: {post.reddit_id}")
+        self.post = post
         data = {
-            "id": submission_id,
-            "score": submission.score,
-            "num_comments": submission.num_comments,
+            "reddit_id": self.post.reddit_id,
+            "title": self.post.title,
+            "name": self.post.name,
+            "url": self.post.url,
+            "created_utc": self.post.created_utc,
+            "selftext": self.post.selftext,
+            "permalink": self.post.permalink,
         }
-        self.send_api_request("PATCH", f"{self.api_url}submissions/{submission_id}/", data=data)
-
-    def get_submissions_to_update(self):
-        submissions_to_update = self.send_api_request("GET", f"{self.api_url}submissions/?last_7_days=1")
-        return submissions_to_update
-
-    def insert_submission(self, submission):
+        response = self.api_request_handler.send_api_request(
+            "POST", f"{self.api_request_handler.api_url}posts/", data=data
+        )
+        return response
+
+    def get_posts_from_last_7_days(self) -> dict:
+        self.log_manager.log("Getting posts from last 7 days")
+        posts_from_last_7_days = self.api_request_handler.send_api_request(
+            "GET", f"{self.api_request_handler.api_url}posts/?last_7_days=1"
+        )
+        return posts_from_last_7_days
+
+
+class PostAnalyticsManager:
+    def __init__(
+        self, api_request_handler: ApiRequestHandler, post_manager: PostManager
+    ):
+        self.api_request_handler = api_request_handler
+        self.post_manager = post_manager
+        self.log_manager = LoggingManager("scraper.log")
+
+    def check_update_requirements(self, reddit_id: str) -> bool:
+        self.log_manager.log(f"Checking update requirements for {reddit_id}")
+
+        # Specify your desired timezone, e.g., UTC
+        timezone = ZoneInfo("UTC")
+
+        # Make your datetime objects timezone-aware
+        fifteen_minutes_ago = datetime.now(timezone) - timedelta(minutes=15)
+        now = datetime.now(timezone)
+
+        # Format datetime objects for the API request
+        time_begin_str = fifteen_minutes_ago.isoformat(timespec="seconds")
+        time_end_str = now.isoformat(timespec="seconds")
+
+        post_id = self.post_manager.get_post_by_reddit_id(reddit_id)
+        post_id = post_id[0]["id"]
+        self.log_manager.log(
+            f"{self.api_request_handler.api_url}post_analytics/?post={post_id}&time_begin={time_begin_str}&time_end={time_end_str}"
+        )
+
+        response = self.api_request_handler.send_api_request(
+            "GET",
+            f"{self.api_request_handler.api_url}post_analytics/?post={post_id}&time_begin={time_begin_str}&time_end={time_end_str}",
+        )
+
+        if len(response) > 0:
+            # post should not be updated
+            return False
+
+        # post should be updated
+        return True
+
+    def update_post_analytics(self, post: Post) -> dict:
+        self.log_manager.log(f"Updating post analytics for {post.reddit_id}")
+        post_id = self.post_manager.get_post_by_reddit_id(post.reddit_id)
+        post_id = post_id[0]["id"]
         data = {
-            "reddit_id": submission.reddit_id,
-            "title": submission.title,
-            "name": submission.name,
-            "url": submission.url,
-            "created_utc": submission.created_utc,
-            "selftext": submission.selftext,
-            "permalink": submission.permalink,
-            "upvote_ratio": submission.upvote_ratio,
+            "post": post_id,
+            "score": post.score,
+            "num_comments": post.num_comments,
+            "upvote_ratio": post.upvote_ratio,
         }
-        response = self.send_api_request("POST", f"{self.api_url}submissions/", data=data)
-        logging.info("Inserting submission")
-        logging.info(response)
+        response = self.api_request_handler.send_api_request(
+            "POST", f"{self.api_request_handler.api_url}post_analytics/", data=data
+        )
+        return response
+
+
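
# --- Reviewer sketch (not part of the diff) ---
# check_update_requirements() throttles analytics writes to one row per
# 15-minute window. For reference, the timestamp format it puts on the wire,
# which the server later feeds to Django's parse_datetime():
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

now = datetime.now(ZoneInfo("UTC"))
print((now - timedelta(minutes=15)).isoformat(timespec="seconds"))
# -> e.g. "2024-03-04T05:00:00+00:00" (ISO 8601, timezone-aware)
# --- end sketch ---
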
+class RedditMonitor:
+    def __init__(
+        self, client_id, client_secret, user_agent, username, password, subreddit_name
+    ):
+        self.reddit = praw.Reddit(
+            client_id=client_id,
+            client_secret=client_secret,
+            user_agent=user_agent,
+            username=username,
+            password=password,
+        )
+        self.subreddit = self.reddit.subreddit(subreddit_name)
+        self.log_manager = LoggingManager("scraper.log")
+
+    def stream_submissions(self):
+        self.log_manager.info("Starting submission stream")
+        for submission in self.subreddit.stream.submissions():
+            yield submission
+
+    def update_submissions(self, posts_to_update):
+        self.log_manager.info("Updating submissions")
+        for post in posts_to_update:
+            submission = self.reddit.submission(id=post["reddit_id"])
+            yield submission
+
+
+class SubmissionManager:
+    def __init__(
+        self,
+        reddit_monitor: RedditMonitor,
+        post_manager: PostManager,
+        post_analytics_manager: PostAnalyticsManager,
+        webhook_notifier,
+    ):
+        self.reddit_monitor = reddit_monitor
+        self.post_manager = post_manager
+        self.post_analytics_manager = post_analytics_manager
+        self.webhook_notifier = webhook_notifier
+        self.log_manager = LoggingManager("scraper.log")
+
+    def convert_submission_to_post(self, submission):
+        post = Post(
+            reddit_id=submission.id,
+            title=submission.title,
+            name=submission.name,
+            url=submission.url,
+            score=submission.score,
+            num_comments=submission.num_comments,
+            created_utc=submission.created_utc,
+            selftext=submission.selftext,
+            permalink=submission.permalink,
+            upvote_ratio=submission.upvote_ratio,
+        )
+        return post
 
     def process_submissions(self, submissions):
         for submission in submissions:
-            submission = Submission(
-                reddit_id=submission.id,
-                title=submission.title,
-                name=submission.name,
-                url=submission.url,
-                score=submission.score,
-                num_comments=submission.num_comments,
-                created_utc=submission.created_utc,
-                selftext=submission.selftext,
-                permalink=submission.permalink,
-                upvote_ratio=submission.upvote_ratio
-            )
-            if self.submission_exists(submission.reddit_id):
-                self.update_submission_analytics(submission)
+            self.log_manager.log(submission)
+            if self.post_manager.post_exists(submission.id):
+                self.log_manager.log("Post exists")
+                self.log_manager.log(f"post id: {submission.id}")
+                if self.post_analytics_manager.check_update_requirements(submission.id):
+                    self.log_manager.log("Update requirements met")
+                    post = self.convert_submission_to_post(submission)
+                    self.post_analytics_manager.update_post_analytics(post)
             else:
-                self.insert_submission(submission)
-                self.update_submission_analytics(submission)
-                self.webhook_notifier.send_notification(submission)
-
+                post = self.convert_submission_to_post(submission)
+                self.post_manager.insert_post(post)
+                self.post_analytics_manager.update_post_analytics(post)
+                self.webhook_notifier.send_notification(post)
+
+
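
# --- Reviewer sketch (not part of the diff) ---
# The dedupe flow above: a brand-new submission is inserted, gets an analytics
# row, and is announced; a known one only gets a fresh analytics row once the
# 15-minute window has passed. The Post conversion can be exercised locally
# with a stand-in submission object (no PRAW credentials needed); assumes
# scraper/ is importable with its dependencies installed.
from types import SimpleNamespace

from app import SubmissionManager

fake = SimpleNamespace(
    id="1b5abc", title="Deal title", name="t3_1b5abc", url="https://example.com",
    score=10, num_comments=2, created_utc=1709500000.0, selftext="",
    permalink="/r/pkmntcgdeals/comments/1b5abc/deal/", upvote_ratio=0.97,
)
manager = SubmissionManager(None, None, None, None)  # collaborators unused here
print(manager.convert_submission_to_post(fake))
# --- end sketch ---
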
+class Application:
+    def __init__(
+        self,
+        reddit_monitor,
+        webhook_notifier,
+        api_conn,
+        post_manager,
+        post_analytics_manager,
+        submission_manager,
+    ):
+        self.reddit_monitor = reddit_monitor
+        self.webhook_notifier = webhook_notifier
+        self.api_conn = api_conn
+        self.post_manager = post_manager
+        self.post_analytics_manager = post_analytics_manager
+        self.log_manager = LoggingManager("scraper.log")
+        self.submission_manager = submission_manager
+        self.scheduler = None
+        self.thread_manager = None
+
     def periodic_update(self):
-        to_be_updated = self.get_submissions_to_update()
+        self.log_manager.info("Running periodic update")
+        to_be_updated = self.post_manager.get_posts_from_last_7_days()
         submissions = self.reddit_monitor.update_submissions(to_be_updated)
-        self.process_submissions(submissions)
-
-    def run_periodic_update(self, interval=3600):
-        while True:
-            self.periodic_update()
-            print(f"Existing posts Updated at {datetime.now()}")
-            time.sleep(interval)
+        self.submission_manager.process_submissions(submissions)
+
+    def run_periodic_update(self, interval):
+        self.scheduler = Scheduler(interval, self.periodic_update)
+        self.scheduler.run()
 
     def run(self):
-        #update_frequency = 3600 # 3600
-        #update_thread = threading.Thread(target=self.run_periodic_update, args=(update_frequency, ))
-        #update_thread.daemon = True
-        #update_thread.start()
+        self.log_manager.info("Application started")
+        update_frequency = 60 * 15  # 15 minutes in seconds
+        self.thread_manager = ThreadManager(
+            target=self.run_periodic_update, args=(update_frequency,)
+        )
+        self.thread_manager.run()
         submissions = self.reddit_monitor.stream_submissions()
-        self.process_submissions(submissions)
\ No newline at end of file
+        self.submission_manager.process_submissions(submissions)
diff --git a/scraper/app_log.py b/scraper/app_log.py
new file mode 100644
index 0000000..59e4820
--- /dev/null
+++ b/scraper/app_log.py
@@ -0,0 +1,46 @@
+import logging
+from logging.handlers import RotatingFileHandler
+import sys
+
+
+class SingletonMeta(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(SingletonMeta, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class LoggingManager(metaclass=SingletonMeta):
+    def __init__(self, log_file):
+        if not hasattr(self, "logger"):
+            self.log_file = log_file
+            self.logger = logging.getLogger("scraper")
+            self.logger.setLevel(logging.DEBUG)
+
+            file_handler = RotatingFileHandler(
+                self.log_file, maxBytes=1024 * 1024 * 5, backupCount=5
+            )
+            file_handler.setLevel(logging.DEBUG)
+
+            stream_handler = logging.StreamHandler(sys.stdout)
+            stream_handler.setLevel(logging.DEBUG)
+
+            formatter = logging.Formatter(
+                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+            )
+            file_handler.setFormatter(formatter)
+            stream_handler.setFormatter(formatter)
+
+            self.logger.addHandler(file_handler)
+            self.logger.addHandler(stream_handler)
+
+    def log(self, message):
+        self.logger.debug(message)
+
+    def error(self, message):
+        self.logger.error(message)
+
+    def info(self, message):
+        self.logger.info(message)
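
# --- Reviewer sketch (not part of the diff) ---
# LoggingManager is a singleton via SingletonMeta: every class above calls
# LoggingManager("scraper.log"), but only the first call builds the logger,
# so handlers are never stacked twice and all output lands in one rotating
# file. Quick check:
from app_log import LoggingManager

a = LoggingManager("scraper.log")
b = LoggingManager("ignored.log")  # args after the first call are ignored
assert a is b
a.info("one logger, one handler set")
# --- end sketch ---
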
diff --git a/scraper/config.py b/scraper/config.py
index 0ea6c92..ee15e4a 100644
--- a/scraper/config.py
+++ b/scraper/config.py
@@ -7,8 +7,8 @@ class Config:
     PRAW_USERNAME = os.getenv("PRAW_USERNAME")
     PRAW_PASSWORD = os.getenv("PRAW_PASSWORD")
     POKEMANS_WEBHOOK_URL = os.getenv("POKEMANS_WEBHOOK_URL")
-    PKMN_ENV = 'dev' # os.getenv("PKMN_ENV")
+    PKMN_ENV = "dev"  # os.getenv("PKMN_ENV")
     SUBREDDIT_NAME = "pkmntcgdeals"
     USER_AGENT = "praw:zman.video_repost_bot:v0.1.0 (by u/jzman21)"
     DISABLE_WEBHOOK = False
-    API_URL = "http://server:8000/api/"
\ No newline at end of file
+    API_URL = "http://server:8000/api/"
diff --git a/scraper/exceptions.py b/scraper/exceptions.py
new file mode 100644
index 0000000..a1d4e38
--- /dev/null
+++ b/scraper/exceptions.py
@@ -0,0 +1,19 @@
+class InvalidMethodError(Exception):
+    """Exception raised for unsupported HTTP methods."""
+
+    pass
+
+
+class InvalidDataTypeError(Exception):
+    """Exception raised for unsupported data types."""
+
+    pass
+
+
+class APIRequestError(Exception):
+    """Exception raised for API request errors."""
+
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(f"API Request Failed: {status_code} - {message}")
diff --git a/scraper/main.py b/scraper/main.py
index 714c48d..62ebe17 100644
--- a/scraper/main.py
+++ b/scraper/main.py
@@ -1,11 +1,18 @@
-from reddit_monitor import RedditMonitor
 from webhook import WebhookNotifier
-from app import Application
+from app import (
+    Application,
+    RedditMonitor,
+    ApiRequestHandler,
+    PostManager,
+    PostAnalyticsManager,
+    SubmissionManager,
+)
 from config import Config
-import logging
+from app_log import LoggingManager
 
 
 if __name__ == "__main__":
+    log_manager = LoggingManager("scraper.log")
     client_id = Config.PRAW_CLIENT_ID
     client_secret = Config.PRAW_CLIENT_SECRET
     user_agent = Config.USER_AGENT
@@ -17,21 +24,28 @@ if __name__ == "__main__":
     pkmn_env = Config.PKMN_ENV
     api_url = Config.API_URL
 
-    # logging
-    logging.basicConfig(filename='scraper.log', level=logging.DEBUG)
-    logging.info('Starting scraper')
-
-    reddit_monitor = RedditMonitor(client_id, client_secret, user_agent, username, password, subreddit_name)
+    reddit_monitor = RedditMonitor(
+        client_id, client_secret, user_agent, username, password, subreddit_name
+    )
     webhook_notifier = WebhookNotifier(discord_webhook_url, disable_webhook)
-    app = Application(reddit_monitor, webhook_notifier, api_url)
+    api_conn = ApiRequestHandler(api_url)
+    post_manager = PostManager(api_conn)
+    post_analytics_manager = PostAnalyticsManager(api_conn, post_manager)
+    submission_manager = SubmissionManager(
+        reddit_monitor, post_manager, post_analytics_manager, webhook_notifier
+    )
+    app = Application(
+        reddit_monitor,
+        webhook_notifier,
+        api_conn,
+        post_manager,
+        post_analytics_manager,
+        submission_manager,
+    )
     app.run()
 
 
 """
 TODO:
-- django rest framework
-- api for managing database
-- remove scraper models
-- connect scraper to django rest framework api
 - pull upvote ratio into analytics?
 - sqlite vs postgres figure out
 - basic front end (react)
@@ -44,4 +58,4 @@ TODO:
 - try to identify platform ie. costco for gift card, tiktok for coupons, etc.
 - support for craigslist, ebay, etc.
 - front end - visualization, classification, lookup, etc.
-"""
\ No newline at end of file
+"""
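
# --- Reviewer sketch (not part of the diff) ---
# main.py is now the single composition root, so a local dry run only needs
# the environment that config.py reads. The PRAW_CLIENT_* names are assumed
# from their Config attribute names; the others appear verbatim in config.py.
import os

for key in ("PRAW_CLIENT_ID", "PRAW_CLIENT_SECRET", "PRAW_USERNAME",
            "PRAW_PASSWORD", "POKEMANS_WEBHOOK_URL"):
    os.environ.setdefault(key, "placeholder")  # real values needed to stream
# then: python scraper/main.py
# --- end sketch ---
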
-""" \ No newline at end of file +""" diff --git a/scraper/models.py b/scraper/models.py index 58b8c9d..61efb83 100644 --- a/scraper/models.py +++ b/scraper/models.py @@ -1,5 +1,17 @@ -class Submission(): - def __init__(self, reddit_id, title, name, url, score, num_comments, created_utc, selftext, permalink, upvote_ratio): +class Post: + def __init__( + self, + reddit_id, + title, + name, + url, + score, + num_comments, + created_utc, + selftext, + permalink, + upvote_ratio, + ): self.reddit_id = reddit_id self.title = title self.name = name @@ -10,6 +22,6 @@ class Submission(): self.selftext = selftext self.permalink = permalink self.upvote_ratio = upvote_ratio - + def __str__(self): - return f"{self.reddit_id} {self.title} {self.name} {self.url} {self.score} {self.num_comments} {self.created_utc} {self.selftext} {self.permalink} {self.upvote_ratio}" \ No newline at end of file + return f"{self.reddit_id} {self.title} {self.name} {self.url} {self.score} {self.num_comments} {self.created_utc} {self.selftext} {self.permalink} {self.upvote_ratio}" diff --git a/scraper/reddit_monitor.py b/scraper/reddit_monitor.py deleted file mode 100644 index a5e4ae0..0000000 --- a/scraper/reddit_monitor.py +++ /dev/null @@ -1,23 +0,0 @@ -import praw -from datetime import datetime, timedelta - - -class RedditMonitor: - def __init__(self, client_id, client_secret, user_agent, username, password, subreddit_name): - self.reddit = praw.Reddit( - client_id=client_id, - client_secret=client_secret, - user_agent=user_agent, - username=username, - password=password - ) - self.subreddit = self.reddit.subreddit(subreddit_name) - - def stream_submissions(self): - for submission in self.subreddit.stream.submissions(): - yield submission - - def update_submissions(self, submissions_to_update): - for submission in submissions_to_update: - praw_submission = self.reddit.submission(id=submission['reddit_id']) - yield praw_submission \ No newline at end of file diff --git a/scraper/threads.py b/scraper/threads.py new file mode 100644 index 0000000..bf2d44b --- /dev/null +++ b/scraper/threads.py @@ -0,0 +1,26 @@ +import threading + + +class Scheduler: + def __init__(self, interval, function): + self.interval = interval + self.function = function + self.stop_event = threading.Event() + + def run(self): + while not self.stop_event.wait(self.interval): + self.function() + + def stop(self): + self.stop_event.set() + + +class ThreadManager: + def __init__(self, target, args: tuple = ()) -> None: + self.target = target + self.args = args + + def run(self): + thread = threading.Thread(target=self.target, args=self.args) + thread.daemon = True + thread.start() diff --git a/scraper/webhook.py b/scraper/webhook.py index 9c40edd..1767e3b 100644 --- a/scraper/webhook.py +++ b/scraper/webhook.py @@ -1,21 +1,27 @@ import requests +from app_log import LoggingManager class WebhookNotifier: def __init__(self, webhook_url, disable_webhook=False): self.webhook_url = webhook_url self.disable_webhook = disable_webhook + self.log_manager = LoggingManager("scraper.log") - def send_notification(self, submission): - title = submission.title - url = submission.url - permalink = submission.permalink - selftext = submission.selftext + def send_notification(self, post): + title = post.title + url = post.url + permalink = post.permalink + selftext = post.selftext content = f""" - **New Deal!** - **Title:** {title} - **URL:** {url} - **Permalink:** https://old.reddit.com{permalink} - **Selftext:** {selftext}""" +**New Deal!** +**Title:** {title} +**URL:** 
diff --git a/scraper/webhook.py b/scraper/webhook.py
index 9c40edd..1767e3b 100644
--- a/scraper/webhook.py
+++ b/scraper/webhook.py
@@ -1,21 +1,27 @@
 import requests
+from app_log import LoggingManager
 
 
 class WebhookNotifier:
     def __init__(self, webhook_url, disable_webhook=False):
         self.webhook_url = webhook_url
         self.disable_webhook = disable_webhook
+        self.log_manager = LoggingManager("scraper.log")
 
-    def send_notification(self, submission):
-        title = submission.title
-        url = submission.url
-        permalink = submission.permalink
-        selftext = submission.selftext
+    def send_notification(self, post):
+        title = post.title
+        url = post.url
+        permalink = post.permalink
+        selftext = post.selftext
         content = f"""
-        **New Deal!**
-        **Title:** {title}
-        **URL:** {url}
-        **Permalink:** https://old.reddit.com{permalink}
-        **Selftext:** {selftext}"""
+**New Deal!**
+**Title:** {title}
+**URL:** {url}
+**Permalink:** https://old.reddit.com{permalink}
+**Selftext:** {selftext}"""
         if not self.disable_webhook:
-            requests.post(self.webhook_url, data={"content": content})
\ No newline at end of file
+            self.log_manager.log(f"Sending notification to {self.webhook_url}")
+            try:
+                requests.post(self.webhook_url, data={"content": content})
+            except Exception as e:
+                self.log_manager.error(f"Failed to send notification: {e}")
\ No newline at end of file
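
# --- Reviewer sketch (not part of the diff) ---
# send_notification() now logs and swallows network errors instead of killing
# the stream loop. With the webhook disabled no HTTP call is made, so the
# message formatting is easy to verify locally:
from types import SimpleNamespace

from webhook import WebhookNotifier

notifier = WebhookNotifier("https://example.invalid/webhook", disable_webhook=True)
notifier.send_notification(SimpleNamespace(
    title="Deal", url="https://example.com", permalink="/r/pkmntcgdeals/1", selftext=""
))
# note: requests.post() is still called without a timeout, so a hung socket
# could stall the daemon thread -- possible follow-up.
# --- end sketch ---
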
diff --git a/server/pokemans_app/migrations/0001_initial.py b/server/pokemans_app/migrations/0001_initial.py
index d5ae54a..89634ab 100644
--- a/server/pokemans_app/migrations/0001_initial.py
+++ b/server/pokemans_app/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# Generated by Django 5.0.2 on 2024-03-04 01:40
+# Generated by Django 5.0.2 on 2024-03-04 05:15
 
 import django.db.models.deletion
 from django.db import migrations, models
@@ -8,38 +8,33 @@ class Migration(migrations.Migration):
 
     initial = True
 
-    dependencies = []
+    dependencies = [
+    ]
 
     operations = [
         migrations.CreateModel(
-            name="Submission",
+            name='Post',
             fields=[
-                ("id", models.AutoField(primary_key=True, serialize=False)),
-                ("reddit_id", models.CharField(max_length=255, unique=True)),
-                ("title", models.CharField(max_length=255)),
-                ("name", models.CharField(max_length=255)),
-                ("url", models.CharField(max_length=255)),
-                ("created_utc", models.FloatField()),
-                ("selftext", models.CharField(max_length=255)),
-                ("permalink", models.CharField(max_length=255)),
-                ("upvote_ratio", models.FloatField()),
-                ("updated_at", models.DateTimeField(auto_now=True)),
+                ('id', models.AutoField(primary_key=True, serialize=False)),
+                ('reddit_id', models.CharField(max_length=255, unique=True)),
+                ('title', models.CharField(max_length=255)),
+                ('name', models.CharField(max_length=255)),
+                ('url', models.CharField(max_length=555)),
+                ('created_utc', models.FloatField()),
+                ('selftext', models.CharField(blank=True, max_length=2555, null=True)),
+                ('permalink', models.CharField(max_length=255)),
+                ('updated_at', models.DateTimeField(auto_now=True)),
             ],
         ),
         migrations.CreateModel(
-            name="SubmissionAnalytics",
+            name='PostAnalytics',
             fields=[
-                ("id", models.AutoField(primary_key=True, serialize=False)),
-                ("num_comments", models.IntegerField()),
-                ("score", models.IntegerField()),
-                ("created_at", models.DateTimeField(auto_now=True)),
-                (
-                    "submission",
-                    models.ForeignKey(
-                        on_delete=django.db.models.deletion.CASCADE,
-                        to="pokemans_app.submission",
-                    ),
-                ),
+                ('id', models.AutoField(primary_key=True, serialize=False)),
+                ('num_comments', models.IntegerField()),
+                ('score', models.IntegerField()),
+                ('upvote_ratio', models.FloatField()),
+                ('created_at', models.DateTimeField(auto_now=True)),
+                ('post', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='pokemans_app.post')),
             ],
         ),
     ]
diff --git a/server/pokemans_app/migrations/0002_alter_submission_selftext.py b/server/pokemans_app/migrations/0002_alter_submission_selftext.py
deleted file mode 100644
index 7cf72a3..0000000
--- a/server/pokemans_app/migrations/0002_alter_submission_selftext.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Generated by Django 5.0.2 on 2024-03-04 03:51
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
-    dependencies = [
-        ("pokemans_app", "0001_initial"),
-    ]
-
-    operations = [
-        migrations.AlterField(
-            model_name="submission",
-            name="selftext",
-            field=models.CharField(blank=True, max_length=1234),
-        ),
-    ]
diff --git a/server/pokemans_app/models.py b/server/pokemans_app/models.py
index 814d33a..a9566e6 100644
--- a/server/pokemans_app/models.py
+++ b/server/pokemans_app/models.py
@@ -1,22 +1,22 @@
 from django.db import models
 
 
-class Submission(models.Model):
+class Post(models.Model):
     id = models.AutoField(primary_key=True)
     reddit_id = models.CharField(max_length=255, unique=True)
     title = models.CharField(max_length=255)
     name = models.CharField(max_length=255)
-    url = models.CharField(max_length=255)
+    url = models.CharField(max_length=555)
     created_utc = models.FloatField()
-    selftext = models.CharField(max_length=1234, blank=True)
+    selftext = models.CharField(max_length=2555, blank=True, null=True)
     permalink = models.CharField(max_length=255)
-    upvote_ratio = models.FloatField()
     updated_at = models.DateTimeField(auto_now=True)
 
 
-class SubmissionAnalytics(models.Model):
+class PostAnalytics(models.Model):
     id = models.AutoField(primary_key=True)
-    submission = models.ForeignKey(Submission, on_delete=models.CASCADE)
+    post = models.ForeignKey(Post, on_delete=models.CASCADE)
     num_comments = models.IntegerField()
     score = models.IntegerField()
+    upvote_ratio = models.FloatField()
     created_at = models.DateTimeField(auto_now=True)
\ No newline at end of file
diff --git a/server/pokemans_app/serializers.py b/server/pokemans_app/serializers.py
index 51c66e9..63bf48d 100644
--- a/server/pokemans_app/serializers.py
+++ b/server/pokemans_app/serializers.py
@@ -1,13 +1,13 @@
 from rest_framework import serializers
-from .models import Submission, SubmissionAnalytics
+from .models import Post, PostAnalytics
 
 
-class SubmissionSerializer(serializers.ModelSerializer):
+class PostSerializer(serializers.ModelSerializer):
     class Meta:
-        model = Submission
+        model = Post
         fields = '__all__'
 
 
-class SubmissionAnalyticsSerializer(serializers.ModelSerializer):
+class PostAnalyticsSerializer(serializers.ModelSerializer):
     class Meta:
-        model = SubmissionAnalytics
+        model = PostAnalytics
         fields = '__all__'
\ No newline at end of file
diff --git a/server/pokemans_app/views.py b/server/pokemans_app/views.py
index affe68a..2aee3d5 100644
--- a/server/pokemans_app/views.py
+++ b/server/pokemans_app/views.py
@@ -1,17 +1,18 @@
 from django.shortcuts import render
 from rest_framework import viewsets
-from .models import Submission, SubmissionAnalytics
-from .serializers import SubmissionSerializer, SubmissionAnalyticsSerializer
+from .models import Post, PostAnalytics
+from .serializers import PostSerializer, PostAnalyticsSerializer
 from datetime import timedelta
 from django.utils import timezone
+from django.utils.dateparse import parse_datetime
 
 
-class SubmissionViewSet(viewsets.ModelViewSet):
-    queryset = Submission.objects.all()
-    serializer_class = SubmissionSerializer
+class PostViewSet(viewsets.ModelViewSet):
+    queryset = Post.objects.all()
+    serializer_class = PostSerializer
 
     def get_queryset(self):
-        queryset = Submission.objects.all()
+        queryset = Post.objects.all()
         reddit_id = self.request.query_params.get('reddit_id', None)
         last_7_days = self.request.query_params.get('last_7_days', None)
 
@@ -27,6 +28,30 @@ class SubmissionViewSet(viewsets.ModelViewSet):
         return queryset
 
-class SubmissionAnalyticsViewSet(viewsets.ModelViewSet):
-    queryset = SubmissionAnalytics.objects.all()
-    serializer_class = SubmissionAnalyticsSerializer
\ No newline at end of file
+class PostAnalyticsViewSet(viewsets.ModelViewSet):
+    queryset = PostAnalytics.objects.all()
+    serializer_class = PostAnalyticsSerializer
+
+    def get_queryset(self):
+        queryset = PostAnalytics.objects.all()
+        post_id = self.request.query_params.get('post', None)
+        time_begin = self.request.query_params.get('time_begin', None)
+        time_end = self.request.query_params.get('time_end', None)
+
+        if post_id is not None:
+            queryset = queryset.filter(post=post_id)
+
+        if time_begin is not None and time_end is not None:
+            # Parse the datetime strings to timezone-aware datetime objects
+            time_begin_parsed = parse_datetime(time_begin)
+            time_end_parsed = parse_datetime(time_end)
+
+            # Ensure datetime objects are timezone-aware
+            if time_begin_parsed is not None and time_end_parsed is not None:
+                queryset = queryset.filter(created_at__gte=time_begin_parsed, created_at__lte=time_end_parsed)
+            else:
+                # Handle invalid datetime format
+                # This is where you could log an error or handle the case where datetime strings are invalid
+                pass
+
+        return queryset
\ No newline at end of file
diff --git a/server/pokemans_django/urls.py b/server/pokemans_django/urls.py
index ac2d4af..ef93e59 100644
--- a/server/pokemans_django/urls.py
+++ b/server/pokemans_django/urls.py
@@ -17,12 +17,12 @@ Including another URLconf
 from django.contrib import admin
 from django.urls import path, include
 from rest_framework.routers import DefaultRouter
-from pokemans_app.views import SubmissionViewSet, SubmissionAnalyticsViewSet
+from pokemans_app.views import PostViewSet, PostAnalyticsViewSet
 
 router = DefaultRouter()
-router.register(r"submissions", SubmissionViewSet)
-router.register(r"submission_analytics", SubmissionAnalyticsViewSet)
+router.register(r"posts", PostViewSet)
+router.register(r"post_analytics", PostAnalyticsViewSet)
 
 urlpatterns = [
     path("admin/", admin.site.urls),
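
# --- Reviewer sketch (not part of the diff) ---
# The renamed routes, end to end: with the router registrations above, the
# scraper's two query patterns map onto these endpoints (host assumed from
# config.py's API_URL).
import requests

base = "http://server:8000/api/"
recent = requests.get(f"{base}posts/", params={"last_7_days": 1}).json()
window = requests.get(
    f"{base}post_analytics/",
    params={
        "post": 1,  # hypothetical Post pk
        "time_begin": "2024-03-04T05:00:00+00:00",
        "time_end": "2024-03-04T05:15:00+00:00",
    },
).json()
print(len(recent), len(window))
# --- end sketch ---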