kemono2/src/pages/api/v1/importer.py
2025-04-02 16:32:47 +02:00

193 lines
5.9 KiB
Python

import base64
import logging
import re
from typing import TypedDict, Union, Literal, NotRequired
import orjson
from flask import make_response, request, session, jsonify
from src.config import Configuration
from src.internals.cache.redis import get_conn
from src.internals.database.database import query_db, query_one_db, query_rowcount_db
from src.lib.imports import validate_import_key
from src.lib.api import create_client_error_response
from src.pages.api.v1 import v1api_bp
# JSON body shape for an OnlyFans import submitted to importer_submit.
# The functional TypedDict syntax is required here because the "x-bc" key
# contains a hyphen and therefore cannot be declared as a class attribute.
TDOnlyFansImportCreateBody = TypedDict(
    "TDOnlyFansImportCreateBody",
    {
        "service": Literal["onlyfans"],
        # Becomes the "sess" entry of the packed OnlyFans import key.
        "session_key": str,
        # Passed through unchanged into the queued job's input.
        "auto_import": str | int | None,
        # Passed through unchanged into the queued job's input.
        "save_session_key": str | int | None,
        # The next three values are stripped of whitespace/quotes and packed together
        # with session_key into a base64-encoded JSON import key.
        "x-bc": str,
        "auth_id": str,
        "user_agent": str,
    },
)
class TDPatreonImportCreateBody(TypedDict):
service: Literal["patreon"]
session_key: str
auto_import: str | int | None
save_session_key: str | int | None
save_dms: NotRequired[bool]
class TDDiscordImportCreateBody(TypedDict):
service: Literal["discord"]
session_key: str
auto_import: str | int | None
save_session_key: str | int | None
channel_ids: str
# Characters stripped from both ends of pasted credentials (after whitespace is trimmed).
_CREDENTIAL_WRAPPERS = "\" \t'"

# A Discord channel (or message) link; captures the channel id, i.e. the second number in
# https://discord.com/channels/<guild_id>/<channel_id>[/<message_id>].
_DISCORD_CHANNEL_URL = re.compile(
    r"https://(?:(?:ptb|canary)\.)?discord(?:app)?\.com/channels/\d+/(?P<ch>\d+)"
)

# Separators accepted between channel ids, including full-width CJK punctuation.
_DISCORD_CHANNEL_SEPARATORS = re.compile(r"[\s,.、。/']")


def _clean_credential(value: str) -> str:
    """Trim whitespace, then surrounding quotes/blanks, from a pasted credential value."""
    return value.strip().strip(_CREDENTIAL_WRAPPERS)


def _parse_discord_channels(channel_ids: str) -> list[str]:
    """Split user-entered Discord channel ids/links into non-empty tokens.

    Channel links are reduced to their channel id. Tokens are NOT validated here;
    the caller checks that every token is numeric.
    """
    pieces = []
    for item in channel_ids.split(","):
        # Strip first so "a, https://discord.com/..." still matches the anchored URL regex.
        item = item.strip()
        match = _DISCORD_CHANNEL_URL.match(item)
        pieces.append(match.group("ch") if match else item)
    return [
        token.strip()
        for token in _DISCORD_CHANNEL_SEPARATORS.split(",".join(pieces))
        if token.strip()
    ]


@v1api_bp.post("/importer/submit")
def importer_submit():
    """
    Validate an import request and queue it as a job (or re-prioritize a matching pending job).

    The JSON body must contain ``service`` and ``session_key``; OnlyFans additionally needs
    ``x-bc``, ``auth_id`` and ``user_agent``, Discord needs ``channel_ids``.

    Returns a JSON response ``{"import_id": ...}`` with status 200 on success, a client error
    response for invalid input, or the import-key validation errors as plain text with 422.

    TODO: split into per-service endpoints
    """
    body: Union[TDOnlyFansImportCreateBody, TDPatreonImportCreateBody, TDDiscordImportCreateBody] | None = (
        request.get_json(silent=True)
    )
    if not isinstance(body, dict):
        return create_client_error_response("Request body must be a JSON object.", 400)

    account_id = session.get("account_id")
    session_key = body.get("session_key")
    auto_import = body.get("auto_import")
    save_session_key = body.get("save_session_key")
    service = body.get("service")
    country = request.headers.get(Configuration().webserver["country_header_key"])
    user_agent = request.headers.get("User-Agent")
    save_dms = None
    discord_channels = None

    # Check the session key before using it: calling .strip() on a missing key used to crash.
    if not session_key:
        return create_client_error_response("Session key missing.", 401)
    if not isinstance(session_key, str):
        return create_client_error_response("Session key must be a string.", 400)
    if not service:
        return create_client_error_response("Service is required.", 400)

    key = _clean_credential(session_key)

    # per service validation
    if service == "patreon":
        save_dms = body.get("save_dms")
        if not account_id and save_dms:
            return create_client_error_response("You must be logged in to import direct messages.", 401)
    elif service == "onlyfans":
        credentials = {}
        for field in ("x-bc", "auth_id", "user_agent"):
            value = body.get(field)
            if not isinstance(value, str):
                return create_client_error_response(f"{field} is required.", 400)
            credentials[field] = _clean_credential(value)
        # Key order is part of the stored key string (used for duplicate detection below).
        key_dict = {
            "sess": key,
            "x-bc": credentials["x-bc"],
            "auth_id": credentials["auth_id"],
            "auth_uid_": "None",
            "user_agent": credentials["user_agent"],
        }
        key = base64.b64encode(orjson.dumps(key_dict)).decode()
    elif service == "discord":
        channel_ids = body.get("channel_ids")
        if not channel_ids:
            return create_client_error_response("Channel IDs is required.")
        if not isinstance(channel_ids, str):
            return create_client_error_response("Channel IDs must be a string.", 422)
        channels = _parse_discord_channels(channel_ids)
        # isascii() guards against Unicode digits (e.g. "²") that isdigit() accepts.
        if not all(channel.isascii() and channel.isdigit() for channel in channels):
            msg = "Discord channel ids are numbers, the last number of the url (notice the / between the 2 numbers)"
            # warning, not exception: there is no active exception to attach a traceback from.
            logging.warning(msg, extra=dict(input_channels=channel_ids, discord_channels=channels))
            return create_client_error_response(msg, 422)
        if not channels:
            msg = "Discord submit requires channels"
            logging.warning(msg, extra=dict(input_channels=channel_ids, discord_channels=channels))
            return create_client_error_response(msg, 422)
        discord_channels = ",".join(channels)

    result = validate_import_key(key, service)
    if not result.is_valid:
        return "\n".join(result.errors), 422

    formatted_key = result.modified_result if result.modified_result else key
    queue_name = f"import:{service}"

    # An unfinished job for the same queue and key already exists: reuse it instead of
    # queueing a duplicate.
    existing_imports = query_db(
        b"""
        SELECT job_id
        FROM jobs
        WHERE finished_at IS null
            AND queue_name = %s
            AND job_input ->> 'key' = %s
        """,
        (queue_name, formatted_key),
    )
    if existing_imports:
        existing_import = existing_imports[0]["job_id"]
        # Drops the job's priority value to at most 0 (and lower on each resubmission).
        # NOTE(review): presumably lower values are picked up first — confirm with the job runner.
        query_rowcount_db(
            """
            UPDATE public.jobs
            SET priority = LEAST(priority, 1) - 1
            WHERE job_id = %s;
            """,
            (str(existing_import),),
        )
        return make_response(jsonify(import_id=existing_import), 200)

    data = dict(
        key=formatted_key,
        service=service,
        channel_ids=discord_channels,
        auto_import=auto_import,
        save_session_key=save_session_key,
        save_dms=save_dms,
        contributor_id=account_id,
        priority=1,
        country=country,
        user_agent=user_agent,
    )
    query = b"""
        INSERT INTO jobs
            (queue_name, priority, job_input)
        VALUES
            (%s, %s, %s)
        RETURNING
            job_id;
    """
    import_id = query_one_db(query, (queue_name, 1, orjson.dumps(data).decode()))["job_id"]
    return make_response(jsonify(import_id=import_id), 200)
@v1api_bp.route("/importer/logs/<import_id>")
def get_importer_logs(import_id: str):
    """
    Return the log lines stored for an import as a JSON array of strings.

    Logs are read from the Redis list ``importer_logs:<import_id>``. When the list is
    non-empty, its expiry is (re)set to 48 hours on every read. An unknown id yields ``[]``.
    """
    redis = get_conn()
    key = f"importer_logs:{import_id}"
    # LRANGE 0 -1 returns the whole list, or [] for a missing key, so no LLEN round trip is needed.
    messages = redis.lrange(key, 0, -1)
    if messages:
        redis.expire(key, 60 * 60 * 48)
    payload = orjson.dumps([msg.decode() for msg in messages])
    # Returning bare bytes would be served as text/html; label the body as JSON.
    return payload, 200, {"Content-Type": "application/json"}