import base64
import itertools
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Optional, TypedDict

from src.config import Configuration
from src.internals.cache.redis import get_conn
from src.internals.database.database import cached_count, cached_query
from src.internals.serializers.generic_with_dates import deserialize_dict_list, serialize_dict_list
from src.internals.serializers.post import deserialize_post_list, serialize_post_list
from src.utils.datetime_ import PeriodScale
from src.utils.utils import batched


class Post(TypedDict):
    id: str
    user: str
    service: str
    title: str
    content: str
    embed: dict
    shared_file: bool
    added: datetime
    published: datetime
    edited: datetime
    file: dict
    attachments: list[dict]
    poll: dict
    captions: Any
    tags: list[str]
    incomplete_rewards: Optional[str]


class PostWithFavCount(Post):
    fav_count: int


# Post flag reasons, keyed by their numeric code.
POST_FLAG_REASON_NUMBER_TO_SLUG = {
    -2: "delete-copyright",
    -1: "delete-abuse",
    1: "missing-password",
    2: "offsite-expired",
    10: "post-changed",
    20: "corrupted-files",
    21: "missing-files",
    11: "stale-comments",
    12: "formatting-error",
    8: "reason-other",
}
POST_FLAG_REASON_SLUG_TO_NUMBER = {v: k for k, v in POST_FLAG_REASON_NUMBER_TO_SLUG.items()}
POST_FLAG_CUT_OFF = 0


def count_all_posts(reload=False) -> int:
    key = "global_post_count"
    query = 'SELECT COUNT(*) FROM posts WHERE ("user", service) NOT IN (SELECT id, service FROM dnp)'
    return cached_count(query, key, reload=reload, ex=6000, lock_enabled=True)


def count_all_posts_for_query(q: str, reload=False) -> int:
    # Normalize the query: lowercase each OR-separated term.
    q = " OR ".join(x.lower() for x in q.strip().split(" OR "))
    if q == "":
        return count_all_posts(reload=reload)
    key = f"global_post_count_for_query:{base64.b64encode(q.encode()).decode()}"
    # `&@~` is PGroonga's full-text search (query syntax) operator.
    query = """
        BEGIN;
        SET LOCAL random_page_cost = 0.0001;
        SET LOCAL statement_timeout = 10000;
        SELECT COUNT(*)
        FROM posts
        WHERE (title || ' ' || content) &@~ %s
            AND ("user", service) NOT IN (
                SELECT id, service FROM dnp
            );
        COMMIT;
    """
    # sets_to_fetch=[3]: the SELECT is the fourth statement (index 3) in the batch.
    return cached_count(
        query, key, (q,), reload, prepare=False, client_bind=True, sets_to_fetch=[3], lock_enabled=True
    )


def count_all_posts_for_tag(tags: list[str], service: Optional[str] = None, artist_id: Optional[str] = None) -> int:
    b = base64.b64encode(f"==TAG==\0{tags}".encode()).decode()
    key = f"global_post_count_for_query:{b}"
    query = """
        SELECT COUNT(*)
        FROM posts
        WHERE "tags" @> %s::citext[]
    """
    params = (tags,)
    if service and artist_id:
        query += """
            AND "service" = %s
            AND "user" = %s
        """
        params += (service, artist_id)
    return cached_count(query, key, params)


def get_all_posts_summary(offset: int, limit=50, reload=False, cache_ttl=None):
    # Summary variant: selects fewer columns to roughly halve Redis size and bandwidth.
    key = f"all_posts:summary:{limit}:{offset}"
    query = """
        SELECT id, "user", service, title, substring("content", 1, 50), published, file, attachments
        FROM posts
        WHERE ("user", service) NOT IN (
            SELECT id, service FROM dnp
        )
        ORDER BY added DESC
        OFFSET %s
        LIMIT %s
    """
    extra = {}
    if cache_ttl:
        extra["ex"] = cache_ttl
    return cached_query(
        query,
        key,
        (offset, limit),
        serialize_dict_list,
        deserialize_dict_list,
        reload,
        lock_enabled=True,
        **extra,
    )


def get_all_posts_full(offset: int, limit=50, reload=False):
    key = f"all_posts:full:{limit}:{offset}"
    query = """
        SELECT
            id, "user", service, title, content, embed, shared_file,
            (CASE service WHEN 'fanbox' THEN NULL ELSE added END) AS added,
            published, edited, file, attachments, poll, captions, tags
        FROM posts
        WHERE ("user", service) NOT IN (
            SELECT id, service FROM dnp
        )
        ORDER BY added DESC
        OFFSET %s
        LIMIT %s
    """
    return cached_query(
        query,
        key,
        (offset, limit),
        serialize_dict_list,
        deserialize_dict_list,
        reload,
        lock_enabled=True,
    )


def get_all_posts_for_query(q: str, offset: int, limit=50, reload=False):
    q = " OR ".join(x.lower() for x in q.strip().split(" OR "))
    if q == "":
        return get_all_posts_summary(0, limit, reload, cache_ttl=Configuration().cache_ttl_for_recent_posts)
    key = f"all_posts_for_query:{base64.b64encode(q.encode()).decode()}:{limit}:{offset}"
    query = """
        BEGIN;
        SET LOCAL random_page_cost = 0.0001;
        SET LOCAL statement_timeout = 10000;
        SELECT id, "user", service, title, substring("content", 1, 50), published, file, attachments
        FROM posts
        WHERE (title || ' ' || content) &@~ %s
            AND ("user", service) NOT IN (
                SELECT id, service FROM dnp
            )
        ORDER BY added DESC
        LIMIT %s
        OFFSET %s;
        COMMIT;
    """
    return cached_query(
        query,
        key,
        (q, limit, offset),
        serialize_dict_list,
        deserialize_dict_list,
        reload,
        prepare=False,
        client_bind=True,
        sets_to_fetch=[3],
        lock_enabled=True,
    )


def get_all_channels_for_server(discord_server, reload=False):
    key = f"discord_channels_for_server:{discord_server}"
    query = "SELECT channel_id AS id, name FROM discord_channels WHERE server_id = %s"
    return cached_query(query, key, (discord_server,), reload=reload, ex_on_null=60, lock_enabled=True)


def get_popular_posts_for_date_range(
    start_date: datetime,
    end_date: datetime,
    scale: PeriodScale,
    page: int,
    per_page: int,
    pages_to_query: int,
    expiry: int = Configuration().redis["default_ttl"],
    reload: bool = False,
) -> list[PostWithFavCount]:
    key = f"popular_posts:{scale}:{per_page}:{start_date.isoformat()}-{end_date.isoformat()}"
    redis = get_conn()
    # Cached pages are stored as a Redis list, one serialized page per element.
    result = redis.lindex(key, page)
    if result:
        parsed_result = deserialize_post_list(result)
        if parsed_result:
            return parsed_result
        else:
            return []
    else:
        if page != 0:
            # The list is cached but has no element at this index, so the page is past the end.
            result = redis.lindex(key, 0)
            if result:
                return []

    params = (start_date, end_date, pages_to_query * per_page)
    order_factor = "COUNT(*)"
    if scale == "recent":
        # Weight each favorite by how far into the window it falls instead of counting all equally.
        order_factor = 'SUM((EXTRACT(EPOCH FROM ("created_at" - %s)) / EXTRACT(EPOCH FROM (%s - %s))))::float'
        params = (start_date, end_date, start_date, *params)
    query = f"""
        WITH "top_faves" AS (
            SELECT "service", "post_id", {order_factor} AS fav_count
            FROM "account_post_favorite"
            WHERE "created_at" BETWEEN %s AND %s
            GROUP BY "service", "post_id"
            ORDER BY fav_count DESC
            LIMIT %s
        )
        SELECT
            p.id, p."user", p.service, p.title, substring(p."content", 1, 50),
            p.published, p.file, p.attachments, tf."fav_count"
        FROM "top_faves" AS tf
        INNER JOIN "posts" AS p
            ON p."id" = tf."post_id" AND p."service" = tf."service";
    """
    result = cached_query(
        query,
        key,
        params,
        serialize_fn=lambda x: [serialize_post_list(cache_page) for cache_page in batched(x, per_page)],
        deserialize_fn=lambda x: list(itertools.chain(*(deserialize_post_list(cache_page) for cache_page in x))),
        ex=expiry,
        reload=reload,
        cache_store_method="rpush",
        lock_enabled=True,
    )
    return (result or [])[(page * per_page) : ((page + 1) * per_page)]


def get_tagged_posts(
    tags: list[str], offset: int, limit: int, service: Optional[str] = None, artist_id: Optional[str] = None
) -> list[Post]:
    key = f"tagged_posts:{tags}:{service}:{artist_id}:{offset}"
    query = """
        SELECT
            id, "user", service, title, content, embed, shared_file,
            (CASE service WHEN 'fanbox' THEN NULL ELSE added END) AS added,
            published, edited, file, attachments, poll, captions, tags
        FROM "posts"
        WHERE "tags" @> %s::citext[]
    """
    params: tuple[Any, ...] = (tags,)
    if service and artist_id:
        query += """
            AND "service" = %s
            AND "user" = %s
            ORDER BY published DESC
        """
        params += (service, artist_id)
    else:
        query += " ORDER BY added DESC "
    query += "OFFSET %s LIMIT %s"
    params += (str(offset), str(limit))
    return cached_query(query, key, params)


@dataclass
class Tag:
    tag: str
    post_count: int


def get_all_tags(service: Optional[str] = None, creator_id: Optional[str] = None) -> list[Tag]:
    if creator_id and not service:
        raise Exception("creator_id requires service to be provided as well")
    key = f"tags:{service or ''}:{creator_id or ''}"
    # Per-creator listings keep the original tag casing; the global listing folds tags to lowercase.
    query = f"""
        SELECT {'tag' if creator_id else 'lower(tag)'} AS tag, COUNT(1) AS post_count
        FROM "posts"
        CROSS JOIN UNNEST(tags) AS tag
    """
    params: tuple[str, ...] = ()
    if service and creator_id:
        query += """WHERE "service" = %s AND "user" = %s """
        params += (service, creator_id)
    query += """
        GROUP BY tag
        ORDER BY post_count DESC
        LIMIT 2000
    """
    ex = int(timedelta(hours=(6 if creator_id else 24)).total_seconds())
    return cached_query(query, key, params, ex=ex)
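

# Illustrative usage sketch, not part of the module's API: it assumes the database
# pool and Redis connection behind cached_query/cached_count are already configured,
# and the argument values below are made up for demonstration.
if __name__ == "__main__":
    total = count_all_posts()
    first_page = get_all_posts_summary(offset=0, limit=50)
    trending = get_popular_posts_for_date_range(
        start_date=datetime(2023, 1, 1),
        end_date=datetime(2023, 1, 31),
        scale="recent",
        page=0,
        per_page=50,
        pages_to_query=10,
    )
    print(total, len(first_page), len(trending))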