155 lines
5.3 KiB
Python
155 lines
5.3 KiB
Python
import base64
|
|
from typing import List
|
|
|
|
from src.internals.database.database import cached_count, cached_query, query_db, query_rowcount_db
|
|
from src.internals.serializers.dm import deserialize_dms, serialize_dms
|
|
from src.lib.artist import nload_query_artists
|
|
from src.types.kemono import Approved_DM, Unapproved_DM
|
|
|
|
# Columns used by the `dms` / `unapproved_dms` queries in this module.
# "hash" and "user" stay double-quoted inside the SQL text.
DM_FIELDS_LIST = [
    '"hash"',
    '"user"',
    "service",
    "content",
    "embed",
    "added",
    "published",
    "file",
]
# Comma-separated form interpolated into SELECT / INSERT column lists.
DM_FIELDS = ", ".join(DM_FIELDS_LIST)
|
|
|
|
|
|
def get_unapproved_dms(account_id: int, deleted=False) -> List[Unapproved_DM]:
    """Fetch the unapproved DMs that belong to one contributor.

    Args:
        account_id: Contributor id; it is sent to the database as a string.
        deleted: When true, select rows whose ``deleted_at`` is NOT NULL;
            otherwise select rows whose ``deleted_at`` is NULL.

    Returns:
        One ``Unapproved_DM`` per matching row. Before conversion each row
        dict receives an ``"artist"`` entry looked up by ``(service, user)``;
        a missing or falsy lookup result becomes ``{}``.
    """
    query = f"""
        SELECT {DM_FIELDS}, contributor_id, import_id
        FROM unapproved_dms
        WHERE contributor_id = %s and deleted_at is {"NOT" if deleted else ""} NULL
    """
    rows = query_db(query, (str(account_id),))

    # Resolve every referenced artist in a single batched lookup.
    artist_keys = [(row["service"], row["user"]) for row in rows]
    artists_by_key = nload_query_artists(artist_keys, True)
    for row in rows:
        artist = artists_by_key.get((row["service"], row["user"]))
        row["artist"] = artist or {}

    return [Unapproved_DM.init_from_dict(row) for row in rows]
|
|
|
|
|
|
def has_unapproved_dms(account_id: int, deleted=False) -> bool:
    """Tell whether a contributor has at least one unapproved DM.

    Args:
        account_id: Contributor id; it is sent to the database as a string.
        deleted: When true, look for rows whose ``deleted_at`` is NOT NULL;
            otherwise look for rows whose ``deleted_at`` is NULL.

    Returns:
        ``True`` when the query yields a truthy result, ``False`` otherwise.
    """
    null_check = "NOT" if deleted else ""
    # LIMIT 1: only existence matters, so one row is enough.
    query = f"""
        SELECT true
        FROM unapproved_dms
        WHERE contributor_id = %s and deleted_at is {null_check} NULL
        LIMIT 1
    """
    rows = query_db(query, (str(account_id),))
    return True if rows else False
|
|
|
|
|
|
def count_user_dms(service: str, user_id: str, reload: bool = False) -> int:
    """Count the approved DMs of one user on one service.

    Args:
        service: Service name; anything other than ``"patreon"`` yields 0
            without touching the cache or the database.
        user_id: User id matched against the ``"user"`` column.
        reload: Forwarded to ``cached_count``.

    Returns:
        The value produced by ``cached_count``, or 0 for other services.
    """
    if service != "patreon":
        return 0

    return cached_count(
        'SELECT COUNT(*) FROM dms WHERE service = %s AND "user" = %s',
        f"dms_count:{service}:{user_id}",
        (service, user_id),
        reload,
    )
|
|
|
|
|
|
def get_artist_dms(service: str, artist_id: str, reload: bool = False) -> List[Approved_DM]:
    """Fetch the approved DMs of one artist through the query cache.

    Args:
        service: Service name matched against the ``service`` column.
        artist_id: Artist id matched against the ``"user"`` column.
        reload: Forwarded to ``cached_query``.

    Returns:
        One ``Approved_DM`` per row returned by ``cached_query``.
    """
    cache_key = f"dms:{service}:{artist_id}"
    query = f"""
        SELECT {DM_FIELDS}
        FROM dms
        WHERE service = %s AND "user" = %s
    """
    params = (service, artist_id)
    rows = cached_query(query, cache_key, params, serialize_dms, deserialize_dms, reload)
    return [Approved_DM.init_from_dict(row) for row in rows]
|
|
|
|
|
|
def get_all_dms_count(reload: bool = False) -> int:
    """Return the number of rows in ``dms`` as reported by ``cached_count``.

    Args:
        reload: Forwarded to ``cached_count``.
    """
    return cached_count(
        "SELECT COUNT(*) FROM dms",
        "all_dms_count",
        (),
        reload,
        lock_enabled=True,
    )
|
|
|
|
|
|
def get_all_dms(offset: int, limit: int, reload: bool = False) -> List[Approved_DM]:
    """Fetch one page of approved DMs, newest ``added`` first.

    Args:
        offset: Number of rows to skip (SQL ``OFFSET``).
        limit: Maximum number of rows to return (SQL ``LIMIT``).
        reload: Forwarded to ``cached_query``.

    Returns:
        One ``Approved_DM`` per row. Before conversion each row dict receives
        an ``"artist"`` entry looked up by ``(service, user)``; a missing or
        falsy lookup result becomes ``{}``.
    """
    # The page depends on both offset and limit, so both belong in the cache
    # key (same layout as get_all_dms_by_query). Keying on offset alone would
    # serve a page of the wrong size when a different limit is requested.
    key = f"all_dms:{offset}:{limit}"
    query = f"""
        SELECT {DM_FIELDS}
        FROM dms
        ORDER BY added DESC
        OFFSET %s
        LIMIT %s
    """
    results = cached_query(
        query, key, (offset, limit), serialize_dms, deserialize_dms, reload, lock_enabled=True
    )  # maybe not lock?

    # Resolve every referenced artist in a single batched lookup.
    creator_dict = nload_query_artists([(dm["service"], dm["user"]) for dm in results], True)
    for dm in results:
        dm["artist"] = creator_dict.get((dm["service"], dm["user"])) or {}

    return [Approved_DM.init_from_dict(dm) for dm in results]
|
|
|
|
|
|
def get_all_dms_by_query_count(text_query: str, reload: bool = False) -> int:
    """Count the approved DMs whose content matches a text query.

    Args:
        text_query: Search expression bound to the ``content &@~ %s`` filter.
        reload: Forwarded to ``cached_count``.

    Returns:
        The value produced by ``cached_count``.
    """
    # The search text is base64-encoded before being embedded in the cache key.
    encoded_query = base64.b64encode(text_query.encode()).decode()
    cache_key = f"all_dms_by_query_count:{encoded_query}"
    query = """
        SELECT COUNT(*) FROM dms WHERE content &@~ %s
    """
    return cached_count(query, cache_key, (text_query,), reload, lock_enabled=True)
|
|
|
|
|
|
def get_all_dms_by_query(text_query: str, offset: int, limit: int, reload: bool = False) -> List[Approved_DM]:
    """Fetch one page of approved DMs matching a text query, newest first.

    Args:
        text_query: Search expression bound to the ``content &@~ %s`` filter.
        offset: Number of rows to skip (SQL ``OFFSET``).
        limit: Maximum number of rows to return (SQL ``LIMIT``).
        reload: Forwarded to ``cached_query``.

    Returns:
        One ``Approved_DM`` per row. Before conversion each row dict receives
        an ``"artist"`` entry looked up by ``(service, user)``; a missing or
        falsy lookup result becomes ``{}``.
    """
    # The search text is base64-encoded before being embedded in the cache key.
    encoded_query = base64.b64encode(text_query.encode()).decode()
    cache_key = f"all_dms_by_query:{offset}:{limit}:{encoded_query}"
    query = f"""
        SELECT {DM_FIELDS}
        FROM dms
        WHERE content &@~ %s
        ORDER BY added DESC
        OFFSET %s
        LIMIT %s
    """
    params = (text_query, offset, limit)
    rows = cached_query(
        query, cache_key, params, serialize_dms, deserialize_dms, reload, lock_enabled=True
    )

    # Resolve every referenced artist in a single batched lookup.
    artist_keys = [(row["service"], row["user"]) for row in rows]
    artists_by_key = nload_query_artists(artist_keys, True)
    for row in rows:
        artist = artists_by_key.get((row["service"], row["user"]))
        row["artist"] = artist or {}

    return [Approved_DM.init_from_dict(row) for row in rows]
|
|
|
|
|
|
def cleanup_unapproved_dms(contributor_id: int, delete=False) -> int:
    """Soft-delete or purge a contributor's unapproved DMs.

    Args:
        contributor_id: Contributor id; it is sent to the database as a string.
        delete: When true, DELETE the rows whose ``deleted_at`` is already
            set. When false, stamp ``deleted_at = CURRENT_TIMESTAMP`` on the
            rows whose ``deleted_at`` is still NULL.

    Returns:
        The value returned by ``query_rowcount_db`` for the executed statement.
    """
    if delete:
        statement = """
                DELETE
                FROM unapproved_dms
                WHERE contributor_id = %s and deleted_at is NOT NULL
            """
    else:
        statement = """
                UPDATE unapproved_dms
                SET deleted_at = CURRENT_TIMESTAMP
                WHERE contributor_id = %s and deleted_at is NULL
            """
    return query_rowcount_db(statement, (str(contributor_id),))
|
|
|
|
|
|
def clean_dms_already_approved(contributor_id: int | None = None):
    """Delete unapproved DMs whose hash already exists in ``dms``.

    Args:
        contributor_id: When given, only that contributor's rows are deleted.
            ``None`` (the default) removes matching rows for every contributor.

    Returns:
        The value returned by ``query_rowcount_db`` for the DELETE statement.
    """
    # Compare against None explicitly: a truthiness test would treat a
    # contributor id of 0 as "no filter" and run the unscoped DELETE.
    scoped = contributor_id is not None
    contributor_clause = " AND contributor_id = %s" if scoped else ""
    params = (str(contributor_id),) if scoped else tuple()
    return query_rowcount_db(
        f"""
            DELETE FROM public.unapproved_dms
            USING public.dms
            WHERE public.unapproved_dms.hash = public.dms.hash {contributor_clause};
        """,
        params,
    )
|
|
|
|
|
|
def approve_dms(contributor_id: int, dm_hashes: list[str]):
    """Copy the selected unapproved DMs of a contributor into ``dms``.

    Args:
        contributor_id: Contributor id; it is sent to the database as a string.
        dm_hashes: Hashes of the DMs to approve, matched with ``hash = ANY(%s)``.

    Rows that conflict on ``("hash", "user", service)`` are skipped.
    """
    # "added" is left out of the copied columns, so the INSERT does not carry
    # the unapproved row's value over.
    columns = DM_FIELDS.replace(", added", "")
    statement = f"""
        INSERT INTO dms ({columns})
        SELECT {columns}
        FROM unapproved_dms
        WHERE contributor_id = %s AND hash = ANY(%s)
        ON CONFLICT ("hash","user", service) DO NOTHING;
    """
    params = (str(contributor_id), dm_hashes)
    query_rowcount_db(statement, params)
|