# kemono2/src/lib/posts.py
import base64
import itertools
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, TypedDict, Any

from src.config import Configuration
from src.internals.cache.redis import get_conn
from src.internals.database.database import cached_count, cached_query
from src.internals.serializers.generic_with_dates import deserialize_dict_list, serialize_dict_list
from src.internals.serializers.post import deserialize_post_list, serialize_post_list
from src.utils.datetime_ import PeriodScale
from src.utils.utils import batched


class Post(TypedDict):
    id: str
    user: str
    service: str
    title: str
    content: str
    embed: dict
    shared_file: bool
    added: datetime
    published: datetime
    edited: datetime
    file: dict
    attachments: list[dict]
    poll: dict
    captions: Any
    tags: list[str]
    incomplete_rewards: Optional[str]


class PostWithFavCount(Post):
    fav_count: int


POST_FLAG_REASON_NUMBER_TO_SLUG = {
    -2: "delete-copyright",
    -1: "delete-abuse",
    1: "missing-password",
    2: "offsite-expired",
    10: "post-changed",
    20: "corrupted-files",
    21: "missing-files",
    11: "stale-comments",
    12: "formatting-error",
    8: "reason-other",
}
POST_FLAG_REASON_SLUG_TO_NUMBER = {v: k for k, v in POST_FLAG_REASON_NUMBER_TO_SLUG.items()}
POST_FLAG_CUT_OFF = 0
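
# The numeric reason codes map to human-readable slugs; judging by the slugs
# and POST_FLAG_CUT_OFF, values below the cut-off appear to be removal reasons
# ("delete-*") while positive values describe re-import/repair flags.
# Hypothetical lookups:
#   POST_FLAG_REASON_NUMBER_TO_SLUG[20]               # "corrupted-files"
#   POST_FLAG_REASON_SLUG_TO_NUMBER["missing-files"]  # 21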


def count_all_posts(reload=False) -> int:
    key = "global_post_count"
    query = 'SELECT COUNT(*) FROM posts WHERE ("user", service) NOT IN (SELECT id, service from dnp)'
    return cached_count(query, key, reload=reload, ex=6000, lock_enabled=True)


def count_all_posts_for_query(q: str, reload=False) -> int:
    q = " OR ".join(x.lower() for x in q.strip().split(" OR "))
    if q == "":
        return count_all_posts(reload=reload)
    key = f"global_post_count_for_query:{base64.b64encode(q.encode()).decode()}"
    query = """
        BEGIN;
        SET LOCAL random_page_cost = 0.0001;
        SET LOCAL statement_timeout = 10000;
        SELECT COUNT(*)
        FROM posts
        WHERE (title || ' ' || content) &@~ %s
            AND ("user", service) NOT IN (
                SELECT id, service
                FROM dnp
            );
        COMMIT;
    """
    return cached_count(query, key, (q,), reload, prepare=False, client_bind=True, sets_to_fetch=[3], lock_enabled=True)
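
# Note: "&@~" looks like a PGroonga-style full-text match operator; the query is
# wrapped in BEGIN / SET LOCAL / COMMIT so the planner hint and statement timeout
# apply only to this statement, and sets_to_fetch=[3] presumably selects the
# result set of the SELECT (the fourth statement in the batch).
# Hypothetical call:
#   total = count_all_posts_for_query("cat OR dog")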


def count_all_posts_for_tag(tags: list[str], service: Optional[str] = None, artist_id: Optional[str] = None) -> int:
    b = base64.b64encode(f"==TAG==\0{tags}".encode()).decode()
    key = f"global_post_count_for_query:{b}"
    query = """
        SELECT
            COUNT(*)
        FROM
            posts
        WHERE
            "tags" @> %s::citext[]
    """
    params = (tags,)
    if service and artist_id:
        query += """
            AND "service" = %s AND "user" = %s
        """
        params += (service, artist_id)
    return cached_count(query, key, params)


def get_all_posts_summary(offset: int, limit=50, reload=False, cache_ttl=None):
    # we need this version to cut the Redis payload size and bandwidth in half
    key = f"all_posts:summary:{limit}:{offset}"
    query = """
        SELECT
            id,
            "user",
            service,
            title,
            substring("content", 1, 50),
            published,
            file,
            attachments
        FROM
            posts
        WHERE
            ("user", service) NOT IN (
                SELECT id, service from dnp
            )
        ORDER BY
            added DESC
        OFFSET %s
        LIMIT %s
    """
    extra = {}
    if cache_ttl:
        extra["ex"] = cache_ttl
    return cached_query(
        query, key, (offset, limit), serialize_dict_list, deserialize_dict_list, reload, lock_enabled=True, **extra
    )
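
# cache_ttl is only forwarded to cached_query when set, so callers (such as the
# empty-query path in get_all_posts_for_query below) can give recent-post pages
# a shorter expiry while other callers keep the default TTL. Hypothetical call:
#   page = get_all_posts_summary(0, limit=50, cache_ttl=300)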


def get_all_posts_full(offset: int, limit=50, reload=False):
    key = f"all_posts:full:{limit}:{offset}"
    query = """
        SELECT
            id,
            "user",
            service,
            title,
            content,
            embed,
            shared_file,
            (
                CASE service
                    WHEN 'fanbox'
                        THEN NULL
                    ELSE added
                END
            ) AS added,
            published,
            edited,
            file,
            attachments,
            poll,
            captions,
            tags
        FROM
            posts
        WHERE
            ("user", service) NOT IN (
                SELECT
                    id,
                    service
                FROM
                    dnp
            )
        ORDER BY
            added DESC
        OFFSET %s
        LIMIT %s
    """
    return cached_query(
        query, key, (offset, limit), serialize_dict_list, deserialize_dict_list, reload, lock_enabled=True
    )


def get_all_posts_for_query(q: str, offset: int, limit=50, reload=False):
    q = " OR ".join(x.lower() for x in q.strip().split(" OR "))
    if q == "":
        return get_all_posts_summary(0, limit, reload, cache_ttl=Configuration().cache_ttl_for_recent_posts)
    key = f"all_posts_for_query:{base64.b64encode(q.encode()).decode()}:{limit}:{offset}"
    query = """
        BEGIN;
        SET LOCAL random_page_cost = 0.0001;
        SET LOCAL statement_timeout = 10000;
        SELECT
            id,
            "user",
            service,
            title,
            substring("content", 1, 50),
            published,
            file,
            attachments
        FROM
            posts
        WHERE
            (title || ' ' || content) &@~ %s
            AND
            ("user", service) NOT IN (
                SELECT id, service
                FROM dnp
            )
        ORDER BY
            added DESC
        LIMIT %s
        OFFSET %s;
        COMMIT;
    """
    return cached_query(
        query,
        key,
        (q, limit, offset),
        serialize_dict_list,
        deserialize_dict_list,
        reload,
        prepare=False,
        client_bind=True,
        sets_to_fetch=[3],
        lock_enabled=True,
    )


def get_all_channels_for_server(discord_server, reload=False):
    key = f"discord_channels_for_server:{discord_server}"
    query = "SELECT channel_id as id, name FROM discord_channels WHERE server_id = %s"
    return cached_query(query, key, (discord_server,), reload=reload, ex_on_null=60, lock_enabled=True)


def get_popular_posts_for_date_range(
    start_date: datetime,
    end_date: datetime,
    scale: PeriodScale,
    page: int,
    per_page: int,
    pages_to_query: int,
    expiry: int = Configuration().redis["default_ttl"],
    reload: bool = False,
) -> list[PostWithFavCount]:
    key = f"popular_posts:{scale}:{per_page}:{start_date.isoformat()}-{end_date.isoformat()}"
    redis = get_conn()
    result = redis.lindex(key, page)
    if result:
        parsed_result = deserialize_post_list(result)
        if parsed_result:
            return parsed_result
        else:
            return []
    else:
        if page != 0:
            result = redis.lindex(key, 0)
            if result:
                return []
    params = (start_date, end_date, pages_to_query * per_page)
    order_factor = "COUNT(*)"
    if scale == "recent":
        order_factor = 'SUM((EXTRACT(EPOCH FROM ("created_at" - %s )) / EXTRACT(EPOCH FROM ( %s - %s )) ))::float'
        params = (start_date, end_date, start_date, *params)
    query = f"""
        WITH "top_faves" AS (
            SELECT "service", "post_id", {order_factor} AS fav_count
            FROM "account_post_favorite"
            WHERE "created_at" BETWEEN %s AND %s
            GROUP BY "service", "post_id"
            ORDER BY fav_count DESC
            LIMIT %s
        )
        SELECT
            p.id,
            p."user",
            p.service,
            p.title,
            substring(p."content", 1, 50),
            p.published,
            p.file,
            p.attachments,
            tf."fav_count"
        FROM
            "top_faves" AS tf
        INNER JOIN
            "posts" AS p
        ON
            p."id" = tf."post_id"
            AND
            p."service" = tf."service";
    """
    result = cached_query(
        query,
        key,
        params,
        serialize_fn=lambda x: [serialize_post_list(cache_page) for cache_page in batched(x, per_page)],
        deserialize_fn=lambda x: list(itertools.chain(*(deserialize_post_list(cache_page) for cache_page in x))),
        ex=expiry,
        reload=reload,
        cache_store_method="rpush",
        lock_enabled=True,
    )
    return (result or [])[(page * per_page) : ((page + 1) * per_page)]
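
# Caching note: cached_query stores one serialized page per Redis list element
# (cache_store_method="rpush"), which is why the fast path above can return a
# single page via redis.lindex(key, page) without touching the database.
# Hypothetical call (scale being whatever PeriodScale value the route uses):
#   posts = get_popular_posts_for_date_range(start, end, scale, page=1,
#                                             per_page=50, pages_to_query=10)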


def get_tagged_posts(
    tags: list[str], offset: int, limit: int, service: Optional[str] = None, artist_id: Optional[str] = None
) -> list[Post]:
    key = f"tagged_posts:{tags}:{service}:{artist_id}:{offset}"
    query = """
        SELECT
            id,
            "user",
            service,
            title,
            content,
            embed,
            shared_file,
            (
                CASE service
                    WHEN 'fanbox'
                        THEN NULL
                    ELSE added
                END
            ) AS added,
            published,
            edited,
            file,
            attachments,
            poll,
            captions,
            tags
        FROM
            "posts"
        WHERE
            "tags" @> %s::citext[]
    """
    params: tuple[Any, ...] = (tags,)
    if service and artist_id:
        query += """
            AND "service" = %s AND "user" = %s ORDER BY published DESC
        """
        params += (service, artist_id)
    else:
        query += " ORDER BY added DESC "
    query += "OFFSET %s LIMIT %s"
    params += (str(offset), str(limit))
    return cached_query(query, key, params)
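
# The OFFSET/LIMIT placeholders are appended after the ORDER BY of either branch
# so both paths share the same pagination tail. Hypothetical usage (illustrative
# service/artist values):
#   posts = get_tagged_posts(["comic"], offset=0, limit=25,
#                            service="fanbox", artist_id="12345")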


@dataclass
class Tag:
    tag: str
    post_count: int


def get_all_tags(service: Optional[str] = None, creator_id: Optional[str] = None) -> list[Tag]:
    if creator_id and not service:
        raise Exception("creator_id must be used together with service")
    key = f"tags:{service or ''}:{creator_id or ''}"
    query = f"""
        SELECT {'tag' if creator_id else 'lower(tag)'} AS tag, COUNT(1) AS post_count
        FROM "posts"
        CROSS JOIN UNNEST(tags) AS tag
    """
    params: tuple[str, ...] = ()
    if service and creator_id:
        query += """WHERE "service" = %s AND "user" = %s """
        params += (service, creator_id)
    query += """
        GROUP BY tag
        ORDER BY post_count DESC
        LIMIT 2000
    """
    ex = int(timedelta(hours=(6 if creator_id else 24)).total_seconds())
    return cached_query(query, key, params, ex=ex)
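
# Tag listing note: the global variant lowercases tags to merge case variants,
# while the per-creator variant keeps the original casing; per-creator results
# are cached for 6 hours versus 24 hours for the global list. Hypothetical usage
# (illustrative identifiers):
#   tags = get_all_tags(service="fanbox", creator_id="12345")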