403 lines
10 KiB
Python
403 lines
10 KiB
Python
import base64
|
|
import itertools
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, TypedDict, Any
|
|
|
|
from src.config import Configuration
|
|
from src.internals.cache.redis import get_conn
|
|
from src.internals.database.database import cached_count, cached_query
|
|
from src.internals.serializers.generic_with_dates import deserialize_dict_list, serialize_dict_list
|
|
from src.internals.serializers.post import deserialize_post_list, serialize_post_list
|
|
from src.utils.datetime_ import PeriodScale
|
|
from src.utils.utils import batched
|
|
|
|
|
|
class Post(TypedDict):
|
|
id: str
|
|
user: str
|
|
service: str
|
|
title: str
|
|
content: str
|
|
embed: dict
|
|
shared_file: bool
|
|
added: datetime
|
|
published: datetime
|
|
edited: datetime
|
|
file: dict
|
|
attachments: list[dict]
|
|
poll: dict
|
|
captions: Any
|
|
tags: list[str]
|
|
incomplete_rewards: Optional[str]
|
|
|
|
|
|
class PostWithFavCount(Post):
|
|
fav_count: int
|
|
|
|
POST_FLAG_REASON_NUMBER_TO_SLUG = {
|
|
-2: "delete-copyright",
|
|
-1: "delete-abuse",
|
|
1: "missing-password",
|
|
2: "offsite-expired",
|
|
10: "post-changed",
|
|
20: "corrupted-files",
|
|
21: "missing-files",
|
|
11: "stale-comments",
|
|
12: "formatting-error",
|
|
8: "reason-other",
|
|
}
|
|
|
|
POST_FLAG_REASON_SLUG_TO_NUMBER = {v:k for k,v in POST_FLAG_REASON_NUMBER_TO_SLUG.items()}
|
|
|
|
POST_FLAG_CUT_OFF = 0
|
|
|
|
def count_all_posts(reload=False) -> int:
|
|
key = "global_post_count"
|
|
query = 'SELECT COUNT(*) FROM posts WHERE ("user", service) NOT IN (SELECT id, service from dnp)'
|
|
|
|
return cached_count(query, key, reload=reload, ex=6000, lock_enabled=True)
|
|
|
|
|
|
def count_all_posts_for_query(q: str, reload=False) -> int:
|
|
q = " OR ".join(x.lower() for x in q.strip().split(" OR "))
|
|
|
|
if q == "":
|
|
return count_all_posts(reload=reload)
|
|
|
|
key = f"global_post_count_for_query:{base64.b64encode(q.encode()).decode()}"
|
|
query = """
|
|
BEGIN;
|
|
SET LOCAL random_page_cost = 0.0001;
|
|
SET LOCAL statement_timeout = 10000;
|
|
SELECT COUNT(*)
|
|
FROM posts
|
|
WHERE (title || ' ' || content) &@~ %s
|
|
AND ("user", service) NOT IN (
|
|
SELECT id, service
|
|
FROM dnp
|
|
);
|
|
COMMIT;
|
|
"""
|
|
|
|
return cached_count(query, key, (q,), reload, prepare=False, client_bind=True, sets_to_fetch=[3], lock_enabled=True)
|
|
|
|
|
|
def count_all_posts_for_tag(tags: list[str], service: Optional[str] = None, artist_id: Optional[str] = None) -> int:
|
|
b = base64.b64encode(f"==TAG==\0{tags}".encode()).decode()
|
|
key = f"global_post_count_for_query:{b}"
|
|
query = """
|
|
SELECT
|
|
COUNT(*)
|
|
FROM
|
|
POSTS
|
|
WHERE
|
|
"tags" @> %s::citext[]
|
|
"""
|
|
params = (tags,)
|
|
|
|
if service and artist_id:
|
|
query += """
|
|
AND "service" = %s AND "user" = %s
|
|
"""
|
|
params += (service, artist_id)
|
|
|
|
return cached_count(query, key, params)
|
|
|
|
|
|
def get_all_posts_summary(offset: int, limit=50, reload=False, cache_ttl=None):
|
|
# we need this version to reduce redis size and bandwidth in half
|
|
key = f"all_posts:summary:{limit}:{offset}"
|
|
query = """
|
|
SELECT
|
|
id,
|
|
"user",
|
|
service,
|
|
title,
|
|
substring("content", 1, 50),
|
|
published,
|
|
file,
|
|
attachments
|
|
FROM
|
|
posts
|
|
WHERE
|
|
("user", service) NOT IN (
|
|
SELECT id, service from dnp
|
|
)
|
|
ORDER BY
|
|
added DESC
|
|
OFFSET %s
|
|
LIMIT %s
|
|
"""
|
|
extra = {}
|
|
|
|
if cache_ttl:
|
|
extra["ex"] = cache_ttl
|
|
|
|
return cached_query(
|
|
query, key, (offset, limit), serialize_dict_list, deserialize_dict_list, reload, lock_enabled=True, **extra
|
|
)
|
|
|
|
|
|
def get_all_posts_full(offset: int, limit=50, reload=False):
|
|
key = f"all_posts:full:{limit}:{offset}"
|
|
query = """
|
|
SELECT
|
|
id,
|
|
"user",
|
|
service,
|
|
title,
|
|
content,
|
|
embed,
|
|
shared_file,
|
|
(
|
|
CASE service
|
|
WHEN 'fanbox'
|
|
THEN NULL
|
|
ELSE added
|
|
END
|
|
) AS added,
|
|
published,
|
|
edited,
|
|
file,
|
|
attachments,
|
|
poll,
|
|
captions,
|
|
tags
|
|
FROM
|
|
posts
|
|
WHERE
|
|
("user", service) NOT IN (
|
|
SELECT
|
|
id,
|
|
service
|
|
FROM
|
|
dnp
|
|
)
|
|
ORDER BY
|
|
added DESC
|
|
OFFSET %s
|
|
LIMIT %s
|
|
"""
|
|
|
|
return cached_query(
|
|
query, key, (offset, limit), serialize_dict_list, deserialize_dict_list, reload, lock_enabled=True
|
|
)
|
|
|
|
|
|
def get_all_posts_for_query(q: str, offset: int, limit=50, reload=False):
|
|
q = " OR ".join(x.lower() for x in q.strip().split(" OR "))
|
|
|
|
if q == "":
|
|
return get_all_posts_summary(0, limit, reload, cache_ttl=Configuration().cache_ttl_for_recent_posts)
|
|
|
|
key = f"all_posts_for_query:{base64.b64encode(q.encode()).decode()}:{limit}:{offset}"
|
|
query = """
|
|
BEGIN;
|
|
SET LOCAL random_page_cost = 0.0001;
|
|
SET LOCAL statement_timeout = 10000;
|
|
SELECT
|
|
id,
|
|
"user",
|
|
service,
|
|
title,
|
|
substring("content", 1, 50),
|
|
published,
|
|
file,
|
|
attachments
|
|
FROM
|
|
posts
|
|
WHERE
|
|
(title || ' ' || content) &@~ %s
|
|
AND
|
|
("user", service) NOT IN (
|
|
SELECT id, service
|
|
FROM dnp
|
|
)
|
|
ORDER BY
|
|
added DESC
|
|
LIMIT %s
|
|
OFFSET %s;
|
|
COMMIT;
|
|
"""
|
|
|
|
return cached_query(
|
|
query,
|
|
key,
|
|
(q, limit, offset),
|
|
serialize_dict_list,
|
|
deserialize_dict_list,
|
|
reload,
|
|
prepare=False,
|
|
client_bind=True,
|
|
sets_to_fetch=[3],
|
|
lock_enabled=True,
|
|
)
|
|
|
|
|
|
def get_all_channels_for_server(discord_server, reload=False):
|
|
key = f"discord_channels_for_server:{discord_server}"
|
|
query = "SELECT channel_id as id, name FROM discord_channels WHERE server_id = %s"
|
|
|
|
return cached_query(query, key, (discord_server,), reload=reload, ex_on_null=60, lock_enabled=True)
|
|
|
|
|
|
def get_popular_posts_for_date_range(
|
|
start_date: datetime,
|
|
end_date: datetime,
|
|
scale: PeriodScale,
|
|
page: int,
|
|
per_page: int,
|
|
pages_to_query: int,
|
|
expiry: int = Configuration().redis["default_ttl"],
|
|
reload: bool = False,
|
|
) -> list[PostWithFavCount]:
|
|
key = f"popular_posts:{scale}:{per_page}:{start_date.isoformat()}-{end_date.isoformat()}"
|
|
|
|
redis = get_conn()
|
|
result = redis.lindex(key, page)
|
|
|
|
if result:
|
|
parsed_result = deserialize_post_list(result)
|
|
if parsed_result:
|
|
return parsed_result
|
|
else:
|
|
return []
|
|
else:
|
|
if page != 0:
|
|
result = redis.lindex(key, 0)
|
|
if result:
|
|
return []
|
|
|
|
params = (start_date, end_date, pages_to_query * per_page)
|
|
order_factor = "COUNT(*)"
|
|
|
|
if scale == "recent":
|
|
order_factor = 'SUM((EXTRACT(EPOCH FROM ("created_at" - %s )) / EXTRACT(EPOCH FROM ( %s - %s )) ))::float'
|
|
params = (start_date, end_date, start_date, *params)
|
|
|
|
query = f"""
|
|
WITH "top_faves" AS (
|
|
SELECT "service", "post_id", {
|
|
order_factor
|
|
} as fav_count
|
|
FROM "account_post_favorite"
|
|
WHERE "created_at" BETWEEN %s AND %s
|
|
GROUP BY "service", "post_id"
|
|
ORDER BY fav_count DESC
|
|
LIMIT %s
|
|
)
|
|
SELECT
|
|
p.id,
|
|
p."user",
|
|
p.service,
|
|
p.title,
|
|
substring( p."content", 1, 50),
|
|
p.published,
|
|
p.file,
|
|
p.attachments,
|
|
tf."fav_count"
|
|
FROM
|
|
"top_faves" AS tf
|
|
INNER JOIN
|
|
"posts" AS p
|
|
ON
|
|
p."id" = tf."post_id"
|
|
AND
|
|
p."service" = tf."service";
|
|
"""
|
|
|
|
result = cached_query(
|
|
query,
|
|
key,
|
|
params,
|
|
serialize_fn=lambda x: [serialize_post_list(cache_page) for cache_page in batched(x, per_page)],
|
|
deserialize_fn=lambda x: list(itertools.chain(*(deserialize_post_list(cache_page) for cache_page in x))),
|
|
ex=expiry,
|
|
reload=reload,
|
|
cache_store_method="rpush",
|
|
lock_enabled=True,
|
|
)
|
|
|
|
return (result or [])[(page * per_page) : ((page + 1) * per_page)]
|
|
|
|
|
|
def get_tagged_posts(
|
|
tags: list[str], offset: int, limit: int, service: Optional[str] = None, artist_id: Optional[str] = None
|
|
) -> list[Post]:
|
|
key = f"tagged_posts:{tags}:{service}:{artist_id}:{offset}"
|
|
query = """
|
|
SELECT
|
|
id,
|
|
"user",
|
|
service,
|
|
title,
|
|
content,
|
|
embed,
|
|
shared_file,
|
|
(
|
|
CASE service
|
|
WHEN 'fanbox'
|
|
THEN NULL
|
|
ELSE added
|
|
END
|
|
) AS added,
|
|
published,
|
|
edited,
|
|
file,
|
|
attachments,
|
|
poll,
|
|
captions,
|
|
tags
|
|
FROM
|
|
"posts"
|
|
WHERE
|
|
"tags" @> %s::citext[]
|
|
"""
|
|
params: tuple[...] = (tags,)
|
|
|
|
if service and artist_id:
|
|
query += """
|
|
AND "service" = %s AND "user" = %s ORDER BY published DESC
|
|
"""
|
|
params += (service, artist_id)
|
|
else:
|
|
query += " ORDER BY added DESC "
|
|
|
|
query += "OFFSET %s LIMIT %s"
|
|
params += (str(offset), str(limit))
|
|
|
|
return cached_query(query, key, params)
|
|
|
|
|
|
@dataclass
|
|
class Tag:
|
|
tag: str
|
|
post_count: int
|
|
|
|
|
|
def get_all_tags(service: Optional[str] = None, creator_id: Optional[str] = None) -> list[Tag]:
|
|
if creator_id and not service:
|
|
raise Exception("Must be used with both creator_id and service")
|
|
|
|
key = f"tags:{service or ""}:{creator_id or ""}"
|
|
query = f"""
|
|
SELECT {"tag" if creator_id else "lower(tag)"} as tag, COUNT(1) AS post_count
|
|
FROM "posts"
|
|
CROSS JOIN UNNEST(tags) AS tag
|
|
"""
|
|
params: tuple[str, ...] = ()
|
|
|
|
if service and creator_id:
|
|
query += """WHERE "service" = %s AND "user" = %s """
|
|
params += (service, creator_id)
|
|
|
|
query += """
|
|
GROUP BY tag
|
|
ORDER BY post_count DESC
|
|
LIMIT 2000
|
|
"""
|
|
ex = int(timedelta(hours=(6 if creator_id else 24)).total_seconds())
|
|
|
|
return cached_query(query, key, params, ex=ex)
|