kemono2/src/lib/dms.py
2024-07-04 22:08:17 +02:00

155 lines
5.3 KiB
Python

import base64
from typing import List
from src.internals.database.database import cached_count, cached_query, query_db, query_rowcount_db
from src.internals.serializers.dm import deserialize_dms, serialize_dms
from src.lib.artist import nload_query_artists
from src.types.kemono import Approved_DM, Unapproved_DM
# Columns read from the `dms` / `unapproved_dms` tables by every query in this
# module. "hash" and "user" are written as quoted identifiers.
DM_FIELDS_LIST = [
    '"hash"',
    '"user"',
    "service",
    "content",
    "embed",
    "added",
    "published",
    "file",
]
# The same columns as a ready-to-interpolate SELECT list.
DM_FIELDS = str.join(", ", DM_FIELDS_LIST)
def get_unapproved_dms(account_id: int, deleted=False) -> List[Unapproved_DM]:
    """Return the unapproved DMs imported by one contributor.

    Args:
        account_id: Contributor whose rows are fetched (sent to the DB as a string).
        deleted: When True, return rows whose ``deleted_at`` is set (soft-deleted)
            instead of the live rows whose ``deleted_at`` is NULL.

    Returns:
        One ``Unapproved_DM`` per row. Each row gets an ``artist`` entry taken from
        ``nload_query_artists``; it falls back to an empty dict when no creator
        is found.
    """
    negation = "NOT" if deleted else ""
    sql = f"""
    SELECT {DM_FIELDS}, contributor_id, import_id
    FROM unapproved_dms
    WHERE contributor_id = %s and deleted_at is {negation} NULL
    """
    rows = query_db(sql, (str(account_id),))
    creators = nload_query_artists([(row["service"], row["user"]) for row in rows], True)
    dms = []
    for row in rows:
        row["artist"] = creators.get((row["service"], row["user"])) or {}
        dms.append(Unapproved_DM.init_from_dict(row))
    return dms
def has_unapproved_dms(account_id: int, deleted=False) -> bool:
    """Tell whether a contributor has at least one unapproved DM.

    Args:
        account_id: Contributor to check (sent to the DB as a string).
        deleted: When True, look at soft-deleted rows (``deleted_at`` set)
            instead of live rows (``deleted_at`` NULL).

    Returns:
        True if the probe query returned any row.
    """
    negation = "NOT" if deleted else ""
    # LIMIT 1: only existence matters, so stop at the first match.
    probe_sql = f"""
    SELECT true
    FROM unapproved_dms
    WHERE contributor_id = %s and deleted_at is {negation} NULL
    LIMIT 1
    """
    rows = query_db(probe_sql, (str(account_id),))
    return bool(rows)
def count_user_dms(service: str, user_id: str, reload: bool = False) -> int:
    """Return the cached number of approved DMs for one creator.

    DMs are only counted for the ``patreon`` service; any other service returns
    0 without consulting the cache or the database.

    Args:
        service: Creator's service name.
        user_id: Creator id within that service.
        reload: Forwarded to ``cached_count``.
    """
    if service != "patreon":
        return 0
    return cached_count(
        'SELECT COUNT(*) FROM dms WHERE service = %s AND "user" = %s',
        f"dms_count:{service}:{user_id}",
        (service, user_id),
        reload,
    )
def get_artist_dms(service: str, artist_id: str, reload: bool = False) -> List[Approved_DM]:
    """Return every approved DM of one creator.

    Results are cached under ``dms:<service>:<artist_id>`` through
    ``cached_query``. No ``artist`` entry is attached to the rows here.

    NOTE(review): the query has no ORDER BY, so row order is whatever the
    database returns.

    Args:
        service: Creator's service name.
        artist_id: Creator id within that service.
        reload: Forwarded to ``cached_query``.
    """
    cache_key = f"dms:{service}:{artist_id}"
    sql = f"""
    SELECT {DM_FIELDS}
    FROM dms
    WHERE service = %s AND "user" = %s
    """
    rows = cached_query(sql, cache_key, (service, artist_id), serialize_dms, deserialize_dms, reload)
    return list(map(Approved_DM.init_from_dict, rows))
def get_all_dms_count(reload: bool = False) -> int:
    """Return the total number of approved DMs, cached as ``all_dms_count``.

    Args:
        reload: Forwarded to ``cached_count`` (called with ``lock_enabled=True``).
    """
    return cached_count(
        "SELECT COUNT(*) FROM dms",
        "all_dms_count",
        (),
        reload,
        lock_enabled=True,
    )
def get_all_dms(offset: int, limit: int, reload: bool = False) -> List[Approved_DM]:
    """Return one page of approved DMs, most recently ``added`` first.

    Args:
        offset: Number of rows to skip.
        limit: Maximum number of rows in the page.
        reload: Forwarded to ``cached_query``.

    Returns:
        ``Approved_DM`` objects. Each row gets an ``artist`` entry taken from
        ``nload_query_artists``; it falls back to an empty dict when no creator
        is found.
    """
    # Both pagination parameters belong in the cache key. Keying on ``offset``
    # alone let a call with one ``limit`` be served a page cached for another
    # (the same pattern get_all_dms_by_query already follows).
    key = f"all_dms:{offset}:{limit}"
    query = f"""
    SELECT {DM_FIELDS}
    FROM dms
    ORDER BY added DESC
    OFFSET %s
    LIMIT %s
    """
    # TODO(review): the original author questioned whether locking is needed here.
    results = cached_query(
        query, key, (offset, limit), serialize_dms, deserialize_dms, reload, lock_enabled=True
    )
    creator_dict = nload_query_artists([(each["service"], each["user"]) for each in results], True)
    for dm in results:
        dm["artist"] = creator_dict.get((dm["service"], dm["user"])) or {}
    return [Approved_DM.init_from_dict(dm) for dm in results]
def get_all_dms_by_query_count(text_query: str, reload: bool = False) -> int:
    """Count approved DMs whose ``content`` matches ``text_query``.

    NOTE(review): ``&@~`` looks like the PGroonga full-text query operator;
    confirm against the DB setup.

    The search text is base64-encoded in the cache key, so the raw user input
    never appears verbatim in the key.

    Args:
        text_query: Search expression, passed to the DB as a bound parameter.
        reload: Forwarded to ``cached_count`` (called with ``lock_enabled=True``).
    """
    sql = """
    SELECT COUNT(*) FROM dms WHERE content &@~ %s
    """
    encoded_query = base64.b64encode(text_query.encode()).decode()
    return cached_count(
        sql,
        f"all_dms_by_query_count:{encoded_query}",
        (text_query,),
        reload,
        lock_enabled=True,
    )
def get_all_dms_by_query(text_query: str, offset: int, limit: int, reload: bool = False) -> List[Approved_DM]:
    """Return one page of approved DMs whose ``content`` matches ``text_query``.

    Rows are ordered by ``added`` descending. The cache key includes the offset,
    the limit and the base64-encoded search text.

    Args:
        text_query: Search expression, passed to the DB as a bound parameter.
        offset: Number of matching rows to skip.
        limit: Maximum number of rows in the page.
        reload: Forwarded to ``cached_query`` (called with ``lock_enabled=True``).

    Returns:
        ``Approved_DM`` objects. Each row gets an ``artist`` entry taken from
        ``nload_query_artists``; it falls back to an empty dict when no creator
        is found.
    """
    encoded_query = base64.b64encode(text_query.encode()).decode()
    cache_key = f"all_dms_by_query:{offset}:{limit}:{encoded_query}"
    sql = f"""
    SELECT {DM_FIELDS}
    FROM dms
    WHERE content &@~ %s
    ORDER BY added DESC
    OFFSET %s
    LIMIT %s
    """
    rows = cached_query(
        sql,
        cache_key,
        (text_query, offset, limit),
        serialize_dms,
        deserialize_dms,
        reload,
        lock_enabled=True,
    )
    creators = nload_query_artists([(row["service"], row["user"]) for row in rows], True)
    matches = []
    for row in rows:
        row["artist"] = creators.get((row["service"], row["user"])) or {}
        matches.append(Approved_DM.init_from_dict(row))
    return matches
def cleanup_unapproved_dms(contributor_id: int, delete=False) -> int:
    """Clean up a contributor's unapproved DMs and return the affected row count.

    Args:
        contributor_id: Owner of the rows (sent to the DB as a string).
        delete: When False (default), soft-delete live rows by stamping
            ``deleted_at = CURRENT_TIMESTAMP``. When True, permanently remove
            the rows that are already soft-deleted.

    Returns:
        Whatever ``query_rowcount_db`` reports for the executed statement.
    """
    params = (str(contributor_id),)
    if not delete:
        soft_delete_sql = """
        UPDATE unapproved_dms
        SET deleted_at = CURRENT_TIMESTAMP
        WHERE contributor_id = %s and deleted_at is NULL
        """
        return query_rowcount_db(soft_delete_sql, params)
    purge_sql = """
    DELETE
    FROM unapproved_dms
    WHERE contributor_id = %s and deleted_at is NOT NULL
    """
    return query_rowcount_db(purge_sql, params)
def clean_dms_already_approved(contributor_id: int | None = None) -> int:
    """Delete unapproved DMs whose hash already exists among approved DMs.

    Args:
        contributor_id: Restrict the cleanup to this contributor's unapproved
            rows. When ``None`` (the default), rows of every contributor are
            considered.

    Returns:
        Whatever ``query_rowcount_db`` reports for the DELETE.
    """
    contributor_filter = ""
    params: tuple[str, ...] = ()
    # Compare against None explicitly. A truthiness test treats an id of 0 as
    # "no filter" and would delete matching rows for every contributor.
    if contributor_id is not None:
        # Qualified because two tables are in scope via USING; an unqualified
        # name would be ambiguous if public.dms has a same-named column.
        contributor_filter = " AND public.unapproved_dms.contributor_id = %s"
        params = (str(contributor_id),)
    return query_rowcount_db(
        f"""
    DELETE FROM public.unapproved_dms
    USING public.dms
    WHERE public.unapproved_dms.hash = public.dms.hash{contributor_filter};
    """,
        params,
    )
def approve_dms(contributor_id: int, dm_hashes: list[str]) -> int:
    """Copy a contributor's selected unapproved DMs into ``dms``.

    Rows that conflict on ``("hash", "user", service)`` with an existing DM are
    skipped (``ON CONFLICT ... DO NOTHING``). The ``added`` column is left out
    of the column list, so inserted rows get the column default (or NULL)
    defined on ``dms``.

    Args:
        contributor_id: Owner of the unapproved rows (sent to the DB as a string).
        dm_hashes: Hashes of the DMs to approve.

    Returns:
        The row count reported by ``query_rowcount_db``. Returns 0 without
        querying when ``dm_hashes`` is empty.
    """
    if not dm_hashes:
        return 0
    # Build the column list from DM_FIELDS_LIST rather than string-replacing
    # ", added" out of DM_FIELDS. The replace silently fails if "added" ever
    # moves to the front or a similarly named column is introduced.
    insert_fields = ", ".join(field for field in DM_FIELDS_LIST if field != "added")
    query = f"""
    INSERT INTO dms ({insert_fields})
    SELECT {insert_fields}
    FROM unapproved_dms
    WHERE contributor_id = %s AND hash = ANY(%s)
    ON CONFLICT ("hash","user", service) DO NOTHING;
    """
    return query_rowcount_db(query, (str(contributor_id), dm_hashes))