import base64
from typing import List

from src.internals.database.database import cached_count, cached_query, query_db, query_rowcount_db
from src.internals.serializers.dm import deserialize_dms, serialize_dms
from src.lib.artist import nload_query_artists
from src.types.kemono import Approved_DM, Unapproved_DM

# Columns selected from both `dms` and `unapproved_dms`. Identifiers are quoted where
# needed ("user" is a reserved word in PostgreSQL).
DM_FIELDS_LIST = ['"hash"', '"user"', "service", "content", "embed", "added", "published", "file"]
DM_FIELDS = ", ".join(DM_FIELDS_LIST)


def _deleted_at_filter(deleted: bool) -> str:
    """Return the SQL predicate selecting soft-deleted rows (``deleted=True``) or live rows."""
    return "deleted_at IS NOT NULL" if deleted else "deleted_at IS NULL"


def _query_cache_token(text_query: str) -> str:
    """Encode free-text search input for use inside a cache key.

    Base64 output contains no ':' characters, so arbitrary user text cannot
    interfere with the ':'-separated key layout.
    """
    return base64.b64encode(text_query.encode()).decode()


def _attach_artists(rows):
    """Set ``row["artist"]`` on each row to its creator record, or ``{}`` when none is found.

    Creators are looked up by each row's ``(service, user)`` pair through
    ``nload_query_artists``. Rows are mutated in place and returned for convenience.
    """
    creator_dict = nload_query_artists([(row["service"], row["user"]) for row in rows], True)
    for row in rows:
        row["artist"] = creator_dict.get((row["service"], row["user"])) or {}
    return rows


def get_unapproved_dms(account_id: int, deleted=False) -> List[Unapproved_DM]:
    """Return the pending (not yet approved) DMs imported by a contributor.

    :param account_id: contributor whose rows in ``unapproved_dms`` are listed.
    :param deleted: when True, return soft-deleted rows (``deleted_at`` set) instead of live ones.
    """
    query = f"""
        SELECT {DM_FIELDS}, contributor_id, import_id
        FROM unapproved_dms
        WHERE contributor_id = %s AND {_deleted_at_filter(deleted)}
    """
    # contributor_id is bound as a string, consistently with every other query on this table.
    result = query_db(query, (str(account_id),))
    _attach_artists(result)
    return [Unapproved_DM.init_from_dict(dm) for dm in result]


def has_unapproved_dms(account_id: int, deleted=False) -> bool:
    """Return True if the contributor has at least one pending DM (soft-deleted ones if ``deleted``)."""
    query = f"""
        SELECT true
        FROM unapproved_dms
        WHERE contributor_id = %s AND {_deleted_at_filter(deleted)}
        LIMIT 1
    """
    result = query_db(query, (str(account_id),))
    return bool(result)


def count_user_dms(service: str, user_id: str, reload: bool = False) -> int:
    """Return the cached number of approved DMs for one creator.

    Only "patreon" is queried; any other service short-circuits to 0 without a database call.
    """
    if service not in ("patreon",):
        return 0
    key = f"dms_count:{service}:{user_id}"
    query = 'SELECT COUNT(*) FROM dms WHERE service = %s AND "user" = %s'
    return cached_count(query, key, (service, user_id), reload)


def get_artist_dms(service: str, artist_id: str, reload: bool = False) -> List[Approved_DM]:
    """Return all approved DMs for one creator, served from the cache unless ``reload`` is set."""
    key = f"dms:{service}:{artist_id}"
    query = f"""
        SELECT {DM_FIELDS}
        FROM dms
        WHERE service = %s AND "user" = %s
    """
    result = cached_query(query, key, (service, artist_id), serialize_dms, deserialize_dms, reload)
    return [Approved_DM.init_from_dict(dm) for dm in result]


def get_all_dms_count(reload: bool = False) -> int:
    """Return the cached total number of approved DMs."""
    key = "all_dms_count"
    query = "SELECT COUNT(*) FROM dms"
    return cached_count(query, key, (), reload, lock_enabled=True)


def get_all_dms(offset: int, limit: int, reload: bool = False) -> List[Approved_DM]:
    """Return one page of approved DMs, newest ``added`` first, with creator info attached.

    :param offset: number of rows to skip.
    :param limit: maximum number of rows to return.
    """
    # The page size is part of the key: caching by offset alone would let a call with a
    # different limit receive a page of the wrong length.
    key = f"all_dms:{offset}:{limit}"
    query = f"""
        SELECT {DM_FIELDS}
        FROM dms
        ORDER BY added DESC
        OFFSET %s
        LIMIT %s
    """
    results = cached_query(
        query, key, (offset, limit), serialize_dms, deserialize_dms, reload, lock_enabled=True
    )  # NOTE(review): original author questioned whether locking is needed here.
    _attach_artists(results)
    return [Approved_DM.init_from_dict(dm) for dm in results]


def get_all_dms_by_query_count(text_query: str, reload: bool = False) -> int:
    """Return the cached number of approved DMs whose content matches ``text_query``.

    ``&@~`` appears to be the PGroonga full-text query operator — confirm against the DB setup.
    """
    query = """
        SELECT COUNT(*)
        FROM dms
        WHERE content &@~ %s
    """
    key = f"all_dms_by_query_count:{_query_cache_token(text_query)}"
    return cached_count(query, key, (text_query,), reload, lock_enabled=True)


def get_all_dms_by_query(text_query: str, offset: int, limit: int, reload: bool = False) -> List[Approved_DM]:
    """Return one page of approved DMs matching ``text_query``, newest first, with creator info attached."""
    key = f"all_dms_by_query:{offset}:{limit}:{_query_cache_token(text_query)}"
    query = f"""
        SELECT {DM_FIELDS}
        FROM dms
        WHERE content &@~ %s
        ORDER BY added DESC
        OFFSET %s
        LIMIT %s
    """
    results = cached_query(
        query, key, (text_query, offset, limit), serialize_dms, deserialize_dms, reload, lock_enabled=True
    )
    _attach_artists(results)
    return [Approved_DM.init_from_dict(dm) for dm in results]


def cleanup_unapproved_dms(contributor_id: int, delete=False) -> int:
    """Clear a contributor's pending DMs in two stages and return the affected row count.

    With ``delete=False``, live rows are soft-deleted by stamping ``deleted_at``.
    With ``delete=True``, rows that were already soft-deleted are removed permanently.
    """
    if delete:
        query = """
            DELETE FROM unapproved_dms
            WHERE contributor_id = %s AND deleted_at IS NOT NULL
        """
    else:
        query = """
            UPDATE unapproved_dms
            SET deleted_at = CURRENT_TIMESTAMP
            WHERE contributor_id = %s AND deleted_at IS NULL
        """
    return query_rowcount_db(query, (str(contributor_id),))


def clean_dms_already_approved(contributor_id: int | None = None):
    """Delete pending DMs whose hash already exists in the approved ``dms`` table.

    :param contributor_id: restrict the cleanup to this contributor; ``None`` cleans all contributors.
    :returns: the row count reported by ``query_rowcount_db``.
    """
    if contributor_id is None:
        contributor_filter, params = "", ()
    else:
        # `is None` (not truthiness) so that an id of 0 still acts as a filter; the column is
        # qualified so it cannot become ambiguous with the joined `dms` table.
        contributor_filter = " AND public.unapproved_dms.contributor_id = %s"
        params = (str(contributor_id),)
    return query_rowcount_db(
        f"""
        DELETE FROM public.unapproved_dms
        USING public.dms
        WHERE public.unapproved_dms.hash = public.dms.hash{contributor_filter};
        """,
        params,
    )


def approve_dms(contributor_id: int, dm_hashes: list[str]):
    """Copy a contributor's selected pending DMs into the approved ``dms`` table.

    ``added`` is omitted from the column list so ``dms`` supplies its own value
    (presumably a column default — confirm against the schema). Rows conflicting on
    ("hash", "user", service) are skipped.

    :returns: the row count reported by ``query_rowcount_db``. The original discarded
        this value, so callers that ignore it are unaffected.
    """
    # Derive the columns from the list rather than string-editing DM_FIELDS, which would
    # break silently if the field order or spacing changed.
    insert_fields = ", ".join(field for field in DM_FIELDS_LIST if field != "added")
    query = f"""
        INSERT INTO dms ({insert_fields})
        SELECT {insert_fields}
        FROM unapproved_dms
        WHERE contributor_id = %s AND hash = ANY(%s)
        ON CONFLICT ("hash", "user", service) DO NOTHING;
    """
    return query_rowcount_db(query, (str(contributor_id), dm_hashes))