commit 6e70db503a3552cc2437a69111882c9c9f9b51b86698a071319253142c51942f Author: admin Date: Tue Apr 14 12:40:40 2026 +0000 Добавить main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..710bb2b --- /dev/null +++ b/main.py @@ -0,0 +1,881 @@ +# FastAPI + +## App + +from fastapi import FastAPI, Depends, HTTPException, Body, Security, Form, Request, Response +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes +from passlib.context import CryptContext +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket +from bson import ObjectId +import redis.asyncio as redis +import jwt +from datetime import datetime, timedelta, timezone +from typing import Dict, Any, List, Set +from pydantic import BaseModel, model_validator +import secrets +import hashlib +import os +import json +from fastapi.templating import Jinja2Templates +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles + + +### Конфигурация + +app = FastAPI(title="FastAPI Auth") +templates = Jinja2Templates(directory="templates") + +app.mount("/static", StaticFiles(directory="static"), name="static") + +PROD = False if os.getenv("PROD") == 'False' else True +SECRET_KEY = os.getenv("SECRET_KEY") +MONGO_INITDB_ROOT_PASSWORD = os.getenv("MONGO_INITDB_ROOT_PASSWORD") +MONGODB_HOST = os.getenv("MONGODB_SERVICE_HOST") +REDIS_HOST = os.getenv("REDIS_SERVICE_HOST") + +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 60 +REFRESH_TOKEN_EXPIRE_DAYS = 170 + +mongo_client = AsyncIOMotorClient(f"mongodb://admin:{MONGO_INITDB_ROOT_PASSWORD}@{MONGODB_HOST}:{27017 if PROD else 30000}/") +db_form = mongo_client["form"] +form_col = db_form['form'] +gridfs = AsyncIOMotorGridFSBucket(db_form) +form_col = mongo_client["form"]['form'] +db = mongo_client["auth"] +users_collection = db["users"] +news_col = mongo_client["back"]['news'] + +redis_client = redis.Redis(host=REDIS_HOST, port=6379 if PROD else 30001, db=0, decode_responses=True) + +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") + + + +### ROLE_SCOPES + +ROLE_SCOPES: Dict[str, List[str]] = { + "user": ["read:items", "create:items"], + "premium": ["read:items", "create:items", "update:items"], + "moderator": ["read:items", "moderate", "update:items", "delete:items"], + "admin": ["read:items", "create:items", "update:items", "delete:items", "moderate", "admin", "api:debug"], +} +BLOCK_ROLES = ['admin'] + + +### Pydantic модели + +class UserRegister(BaseModel): + username: str | None = None + phone: str | None = None + email: str | None = None + password: str + client_id: str + roles: List[str] = ["user"] + extra_scopes: List[str] = [] + + @model_validator(mode='after') + def at_least_one_identifier(self): + if not any([self.username, self.phone, self.email]): + raise ValueError("At least one of username, phone or email is required") + return self + + +class UserLoginForm(OAuth2PasswordRequestForm): + pass + +class LoginRequest(BaseModel): + identifier: str + password: str + client_id: str + +class TokenResponse(BaseModel): + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class RefreshRequest(BaseModel): + refresh_token: str | None = None + client_id: str | None = None + + +### Хелперы + +import orjson +from typing import List +from bson import ObjectId + +from bson import ObjectId + +async def fetch_latest_from_mongo(limit: int = 20) -> List[dict]: + cursor = news_col.find({}).sort("_id", -1).limit(limit) + docs = [] + async for doc in cursor: + # сохранить original id отдельно, затем строковое представление + # doc["_id"] = str(doc.get("_id")) + # если нужен оригинальный ObjectId при дальнейшей обработке, храните его в другом поле + # doc["_oid"] = ObjectId(doc["_id"]) # пример, если надо восстановить + if "published_at" in doc and hasattr(doc["published_at"], "isoformat"): + doc["published_at"] = doc["published_at"].isoformat() + docs.append(doc) + return docs + + +def serialize_for_cache(items: List[dict]) -> bytes: + serializable = [] + for doc in items: + copy = dict(doc) + _id = copy.get("_id") + if isinstance(_id, ObjectId): + copy["_id"] = str(_id) + serializable.append(copy) + return orjson.dumps({"items": serializable}) + +def deserialize_from_cache_to_objectids(payload: bytes) -> List[dict]: + obj = orjson.loads(payload) + items = obj.get("items", []) + for item in items: + _id = item.get("_id") + if isinstance(_id, str): + try: + item["_id"] = ObjectId(_id) + except Exception: + pass + return items + +async def get_latest_news() -> List[dict]: + cached = await redis_client.get('latest_news') + if cached: + try: + # redis client может вернуть str или bytes; orjson.loads принимает bytes/str + return deserialize_from_cache_to_objectids(cached if isinstance(cached, (bytes, str)) else cached.decode()) + except Exception: + pass + + items = await fetch_latest_from_mongo(20) + payload = serialize_for_cache(items) # bytes + # при использовании aioredis: .set accepts bytes + await redis_client.set('latest_news', payload, ex=5) + return items + + +from typing import Optional + +def normalize_phone(phone: Optional[str]) -> Optional[str]: + if not phone or len(phone) < 6: + return None + + cleaned = ''.join(char for char in phone.strip() if char.isdigit() or char == '+') + + if cleaned.startswith('+7'): + return cleaned + + if cleaned.startswith('7'): + return '+' + cleaned + + if cleaned.startswith('8') and len(cleaned) == 11: + return '+7' + cleaned[1:] + + if len(cleaned) >= 10: + return '+7' + cleaned + + +def get_bearer_token(request: Request): + auth = request.headers.get("Authorization") + if auth and auth.startswith("Bearer "): + return auth.split(" ", 1)[1] + + +async def check_bruteforce(username: str): + key = f"login_attempts:{username.lower()}" + attempts = await redis_client.get(key) + if attempts is None: + return + + attempts = int(attempts) + if attempts >= 8: + ttl = await redis_client.ttl(key) + minutes = (ttl + 59) // 60 if ttl > 0 else 15 + raise HTTPException( + status_code=429, + detail=f"Too many failed login attempts. Try again in ~{minutes} minutes." + ) + + +async def increment_failed_attempt(username: str): + key = f"login_attempts:{username.lower()}" + pipe = await redis_client.pipeline() + pipe.incr(key) + pipe.expire(key, 20 * 60) + pipe.execute() + + +async def reset_failed_attempts(username: str): + key = f"login_attempts:{username.lower()}" + await redis_client.delete(key) + + +def verify_password(plain: str, hashed: str) -> bool: + return pwd_context.verify(plain, hashed) + + +def get_password_hash(password: str) -> str: + return pwd_context.hash(password) + + +def create_jwt_token(data: Dict, expires_delta: timedelta) -> str: + to_encode = data.copy() + expire = datetime.now(timezone.utc) + expires_delta + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + +def get_all_user_scopes(roles: List[str], extra_scopes: List[str]) -> List[str]: + scopes: Set[str] = set(extra_scopes) + for role in roles: + if role in ROLE_SCOPES: + scopes.update(ROLE_SCOPES[role]) + return list(scopes) + + +def hash_token(token: str) -> str: + return hashlib.sha256(token.encode()).hexdigest() + + +### CSRF защита + +def generate_csrf_token() -> str: + return secrets.token_urlsafe(32) + + +def set_csrf_cookie(response: Response, csrf_token: str): + response.set_cookie( + key="csrf_token", + value=csrf_token, + httponly=PROD, + secure=PROD, + samesite="lax", + max_age=3600 * 24 + ) + + +async def validate_csrf(request: Request, user_id_str: str, response: Response): + if request.method in ("GET", "HEAD", "OPTIONS", "TRACE"): + return + + cookie_csrf = request.cookies.get("csrf_token") + + if not cookie_csrf: + raise HTTPException(status_code=403, detail="CSRF token mismatch or missing") + + stored_csrf_token = await redis_client.get(f"csrf:{user_id_str}") + + if stored_csrf_token is None or stored_csrf_token != cookie_csrf: + raise HTTPException(status_code=403, detail="CSRF token mismatch or missing") + + csrf_token = generate_csrf_token() + await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token) + set_csrf_cookie(response, csrf_token) + + +### Извлечение access token из cookie или Authorization + +def get_access_token(request: Request) -> str | None: + token = request.cookies.get("access_token") + if token: + return token + auth = request.headers.get("Authorization") + if auth and auth.startswith("Bearer "): + return auth.split(" ")[1] + return None + + +### Получение текущего пользователя + +async def get_current_user( + request: Request, + security_scopes: SecurityScopes +) -> Dict[str, Any]: + token = get_access_token(request) + if not token: + raise HTTPException(status_code=401, detail="Not authenticated") + + key = f"access:{token}" + user_id_str = await redis_client.get(key) + if not user_id_str: + raise HTTPException(status_code=401, detail="Invalid or expired access token") + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + if payload.get("sub") != user_id_str: + raise HTTPException(status_code=401, detail="Token mismatch") + + user = await users_collection.find_one({"_id": ObjectId(user_id_str)}, {"_id": 0, "hashed_password": 0}) + if not user: + raise HTTPException(status_code=401, detail="User not found") + + token_scopes: List[str] = payload.get("scopes", []) + for scope in security_scopes.scopes: + if scope not in token_scopes: + raise HTTPException(status_code=403, detail="Not enough permissions") + + scopes = get_all_user_scopes(user.get("roles", []), user.get("extra_scopes", [])) + user['scopes'] = scopes + user["id"] = user_id_str + return user + + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Access token expired") + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token") + + +### Эндпоинты + +@app.post("/register", response_model=dict) +async def register( + request: Request, + user: UserRegister, + response: Response, +): + if user.username and await users_collection.find_one({"username": user.username}): + raise HTTPException(400, "Имя пользователя занято") + if normalize_phone(user.phone) and await users_collection.find_one({"phone": normalize_phone(user.phone)}): + raise HTTPException(400, "Телефон уже зарегестрирован") + if user.email and await users_collection.find_one({"email": user.email.lower()}): + raise HTTPException(400, "Email уже зарегестрирован") + + hashed_password = get_password_hash(user.password) + + allowed_roles = [r for r in user.roles if r not in BLOCK_ROLES] + + doc = { + "username": user.username, + "phone": normalize_phone(user.phone) if user.phone else None, + "email": user.email.lower() if user.email else None, + "hashed_password": hashed_password, + "roles": allowed_roles, + } + + result = await users_collection.insert_one(doc) + user_id_str = str(result.inserted_id) + + scopes = get_all_user_scopes(allowed_roles, []) + + client_id = user.client_id + client_ip = request.client.host + + access_exp_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_exp_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + + access_token = create_jwt_token( + {"sub": user_id_str, "scopes": scopes, "client_id": client_id, "ip": client_ip}, + access_exp_delta + ) + refresh_token = create_jwt_token( + {"sub": user_id_str, "type": "refresh", "client_id": client_id, "ip": client_ip}, + refresh_exp_delta + ) + + await redis_client.setex(f"access:{access_token}", int(access_exp_delta.total_seconds()), user_id_str) + await redis_client.setex(f"refresh:{hash_token(refresh_token)}", int(refresh_exp_delta.total_seconds()), user_id_str) + + session_key = f"user_sessions:{user_id_str}" + await redis_client.sadd(session_key, f"access:{access_token}") + await redis_client.sadd(session_key, f"refresh:{hash_token(refresh_token)}") + + csrf_token = generate_csrf_token() + await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token) + + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=PROD, + samesite="lax", + max_age=int(access_exp_delta.total_seconds()) + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=PROD, + samesite="lax", + max_age=int(refresh_exp_delta.total_seconds()) + ) + set_csrf_cookie(response, csrf_token) + + return { + "ok": True, + "access_token": access_token, + "refresh_token": refresh_token, + "user_id": user_id_str, + "token_type": "bearer" + } + + +@app.post("/login") +async def login( + request: Request, + response: Response, + login_data: LoginRequest = Body(...), +): + identifier = login_data.identifier.strip() + await check_bruteforce(identifier) + + user = await users_collection.find_one({ + "$or": [ + {"username": identifier}, + {"phone": normalize_phone(identifier)}, + {"email": identifier.lower()}, + ] + }) + + if not user or not verify_password(login_data.password, user["hashed_password"]): + await increment_failed_attempt(identifier) + raise HTTPException(401, "Incorrect credentials") + + await reset_failed_attempts(identifier) + + user_id_str = str(user["_id"]) + scopes = get_all_user_scopes(user.get("roles", []), user.get("extra_scopes", [])) + + client_ip = request.client.host + + access_exp_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_exp_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + + access_token = create_jwt_token( + {"sub": user_id_str, "scopes": scopes, "client_id": login_data.client_id, "ip": client_ip}, + access_exp_delta + ) + refresh_token = create_jwt_token( + {"sub": user_id_str, "type": "refresh", "client_id": login_data.client_id, "ip": client_ip}, + refresh_exp_delta + ) + + await redis_client.setex(f"access:{access_token}", int(access_exp_delta.total_seconds()), user_id_str) + await redis_client.setex(f"refresh:{hash_token(refresh_token)}", int(refresh_exp_delta.total_seconds()), user_id_str) + + session_key = f"user_sessions:{user_id_str}" + await redis_client.sadd(session_key, f"access:{access_token}") + await redis_client.sadd(session_key, f"refresh:{hash_token(refresh_token)}") + + csrf_token = generate_csrf_token() + await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token) + + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=PROD, + samesite="lax", + max_age=int(access_exp_delta.total_seconds()) + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=False, + samesite="lax", + max_age=int(refresh_exp_delta.total_seconds()) + ) + set_csrf_cookie(response, csrf_token) + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "user_id": user_id_str, + "token_type": "bearer" + } + + +@app.post("/refresh", response_model=TokenResponse) +async def refresh_token( + request: Request, + response: Response, + req: RefreshRequest +): + refresh_token = ( + request.cookies.get("refresh_token") or + get_bearer_token(request) or + (req.refresh_token if req else None) + ) + + if not refresh_token: + raise HTTPException(401, "Refresh token required") + + client_id = None + + if req: + client_id = req.client_id + + refresh_hash = hash_token(refresh_token) + user_id_str = await redis_client.get(f"refresh:{refresh_hash}") + + if not user_id_str: + raise HTTPException(status_code=401, detail="Invalid or expired refresh token") + + try: + payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM]) + print(payload,user_id_str) + if payload.get("type") != "refresh": + raise HTTPException(status_code=401, detail="Not a refresh token") + if payload.get("sub") != user_id_str: + raise HTTPException(status_code=401, detail="Token mismatch") + except jwt.InvalidTokenError as e: + raise HTTPException(status_code=401, detail=f"Invalid refresh token: {str(e)}") + + await redis_client.delete(f"refresh:{refresh_hash}") + user = await users_collection.find_one({"_id": ObjectId(user_id_str)}) + + scopes = get_all_user_scopes(user.get("roles", []), user.get("extra_scopes", [])) + + client_ip = request.client.host + print(client_ip) + + access_exp = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_exp = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) + + new_access = create_jwt_token( + {"sub": user_id_str, "scopes": scopes, "client_id": client_id, "ip": client_ip}, + access_exp + ) + new_refresh = create_jwt_token( + {"sub": user_id_str, "type": "refresh", "client_id": client_id, "ip": client_ip}, + refresh_exp + ) + + new_refresh_hash = hash_token(new_refresh) + await redis_client.setex(f"access:{new_access}", int(access_exp.total_seconds()), user_id_str) + await redis_client.setex(f"refresh:{new_refresh_hash}", int(refresh_exp.total_seconds()), user_id_str) + + session_key = f"user_sessions:{user_id_str}" + await redis_client.sadd(session_key, f"access:{new_access}") + await redis_client.sadd(session_key, f"refresh:{new_refresh_hash}") + + csrf_token = generate_csrf_token() + await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token) + + response.set_cookie( + key="access_token", + value=new_access, + httponly=True, + secure=True, + samesite="strict", + max_age=int(access_exp.total_seconds()) + ) + + response.set_cookie( + key="refresh_token", + value=new_refresh, + httponly=True, + secure=True, + samesite="strict", + max_age=int(refresh_exp.total_seconds()) + ) + + set_csrf_cookie(response, csrf_token) + + return { + "access_token": new_access, + "token_type": "bearer", + "refresh_token": new_refresh + } + +@app.post("/logout") +async def logout( + request: Request, + response: Response, + current_user: Dict = Depends(get_current_user) +): + await validate_csrf(request) + + user_id_str = current_user["id"] + + # Удаляем все сессии + session_key = f"user_sessions:{user_id_str}" + tokens = await redis_client.smembers(session_key) + for t in tokens: + await redis_client.delete(t) + await redis_client.delete(session_key) + + # Удаляем csrf + await redis_client.delete(f"csrf:{user_id_str}") + + # Удаляем cookie (expires в прошлом) + response.delete_cookie("access_token") + response.delete_cookie("refresh_token") + response.delete_cookie("csrf_token") + + return {"message": "Logged out from all sessions"} + + +@app.get("/me") +async def get_me( + current_user: Dict = Security(get_current_user, scopes=["read:items"]) +): + return current_user + + +@app.post("/items") +async def create_item( + request: Request, + response: Response, + current_user: Dict = Security(get_current_user, scopes=["create:items"]) +): + await validate_csrf(request, current_user["id"], response) + return {"message": f"Item created by {current_user['username']}"} + + +@app.get("/healthz") +async def healthz(): + return {} + +@app.get("/") +async def main(request: Request): + cookies = request.cookies + headers = request.headers + return {"cookies": cookies, "headers": headers} + +@app.get("/form", response_class=HTMLResponse) +async def form_(request: Request): + return templates.TemplateResponse( + request=request, name="form.html", context={"id": 1} + ) + +from bson import ObjectId +@app.get("/test", response_class=HTMLResponse) +async def read_item(request: Request): + news = await get_latest_news() + return templates.TemplateResponse( + request=request, name="index.html", context={"news": news, "posts": [{'_id':ObjectId(), 'title':'Title', 'description':'Description', 'user': {'name':" F I O", 'grade': 'No grade'}, 'media': [{'url':'/static/logo.jpeg'}]},{'_id':ObjectId(), 'media': [{'url':'/static/logo.jpeg'}]}]} + ) + +import os +import uuid +from fastapi import FastAPI, UploadFile, File, Form, HTTPException +from fastapi.responses import JSONResponse +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket +from pydantic import BaseModel +from typing import Optional + + +class FormOut(BaseModel): + id: str + title: str + topic: Optional[str] + short_desc: Optional[str] + rules_file_id: Optional[str] + poster_file_id: Optional[str] + min_paintings: Optional[int] + max_paintings: Optional[int] + selected_count: Optional[int] + +@app.post("/api", response_model=FormOut) +async def receive_form( + poster: Optional[UploadFile] = File(None), + title: str = Form(...), + topic: Optional[str] = Form(None), + short_desc: Optional[str] = Form(None), + rules: Optional[UploadFile] = File(None), + min_paintings: Optional[int] = Form(None), + max_paintings: Optional[int] = Form(None), + selected_count: Optional[int] = Form(None), +): + # helper to save file to GridFS and return file id (as str) + async def save_to_gridfs(upload: UploadFile): + if upload is None: + return None + filename = upload.filename or str(uuid.uuid4()) + # read bytes + contents = await upload.read() + # upload to GridFS + file_id = await gridfs.upload_from_stream(filename, contents) + return str(file_id) + + try: + poster_id = await save_to_gridfs(poster) + rules_id = await save_to_gridfs(rules) + + doc = { + "title": title, + "topic": topic, + "short_desc": short_desc, + "poster_file_id": poster_id, + "rules_file_id": rules_id, + "min_paintings": int(min_paintings) if min_paintings is not None else None, + "max_paintings": int(max_paintings) if max_paintings is not None else None, + "selected_count": int(selected_count) if selected_count is not None else None, + } + + res = await form_col.insert_one(doc) + doc_out = { + "id": str(res.inserted_id), + **{k: v for k, v in doc.items()} + } + print(doc_out) + return JSONResponse(status_code=201, content=FormOut(**doc_out).dict()) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +from fastapi.responses import StreamingResponse +from bson import ObjectId +import mimetypes +from typing import List + +# helper: stream file from GridFS by id (str) +async def stream_gridfs_file(file_id: str): + if file_id is None: + return None + try: + oid = ObjectId(file_id) + except Exception: + return None + # open download stream + stream = await gridfs.open_download_stream(oid) + # get file metadata + filename = stream.filename or "file" + content_type = stream.content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + # async generator to read chunks + async def file_iterator(chunk_size: int = 1024 * 64): + while True: + chunk = await stream.read(chunk_size) + if not chunk: + break + yield chunk + await stream.close() + return filename, content_type, file_iterator + +from fastapi.responses import StreamingResponse +from bson import ObjectId +import mimetypes +from typing import List + +from urllib.parse import quote + +def content_disposition_header(filename: str, inline: bool = True) -> str: + dispo_type = "inline" if inline else "attachment" + # sanitized ASCII fallback + try: + filename_ascii = filename.encode("latin-1") + # if succeeds, return simple header + return f'{dispo_type}; filename="{filename}"' + except UnicodeEncodeError: + # RFC 5987: filename* with UTF-8''percent-encoded + filename_quoted = quote(filename, safe='') + return f"{dispo_type}; filename*=UTF-8''{filename_quoted}" + +from random import randint + +@app.get("/files/{file_id}") +async def get_file(file_id: str): + # if request asks for random file + if file_id == "random": + files_col = db_form["fs.files"] + total = await files_col.count_documents({}) + if total == 0: + raise HTTPException(status_code=404, detail="No files available") + # pick random skip index + idx = randint(0, max(0, total - 1)) + cursor = files_col.find().skip(idx).limit(1) + doc = await cursor.to_list(length=1) + if not doc: + raise HTTPException(status_code=404, detail="No file found") + file_id = str(doc[0]["_id"]) + + res = await stream_gridfs_file(file_id) + if res is None: + raise HTTPException(status_code=404, detail="File not found") + filename, content_type, iterator = res + headers = {"Content-Disposition": content_disposition_header(filename, inline=True)} + return StreamingResponse(iterator(), media_type=content_type, headers=headers) + + +@app.get("/documents/{doc_id}") +async def get_document(doc_id: str): + try: + oid = ObjectId(doc_id) + except Exception: + raise HTTPException(status_code=400, detail="Invalid document id") + doc = await form_col.find_one({"_id": oid}) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + # construct download/view urls (adjust host/path if needed) + def make_url(fid): + return f"/files/{fid}" if fid else None + return { + "id": str(doc["_id"]), + "title": doc.get("title"), + "topic": doc.get("topic"), + "short_desc": doc.get("short_desc"), + "poster_url": make_url(doc.get("poster_file_id")), + "rules_url": make_url(doc.get("rules_file_id")), + "min_paintings": doc.get("min_paintings"), + "max_paintings": doc.get("max_paintings"), + "selected_count": doc.get("selected_count"), + "task_id": doc.get("task_id"), + } + + +@app.get("/documents/{doc_id}") +async def get_document(doc_id: str): + try: + oid = ObjectId(doc_id) + except Exception: + raise HTTPException(status_code=400, detail="Invalid document id") + doc = await form_col.find_one({"_id": oid}) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + # construct download/view urls (adjust host/path if needed) + def make_url(fid): + return f"/files/{fid}" if fid else None + return { + "id": str(doc["_id"]), + "title": doc.get("title"), + "topic": doc.get("topic"), + "short_desc": doc.get("short_desc"), + "poster_url": make_url(doc.get("poster_file_id")), + "rules_url": make_url(doc.get("rules_file_id")), + "min_paintings": doc.get("min_paintings"), + "max_paintings": doc.get("max_paintings"), + "selected_count": doc.get("selected_count"), + "task_id": doc.get("task_id"), + } + + +from fastapi import Query + +@app.get("/files") +async def list_gridfs_files( + filename: Optional[str] = Query(None, description="Filter by filename (substring)"), + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), +): + q = {} + if filename: + q["filename"] = {"$regex": filename, "$options": "i"} + # GridFS files collection is .fs.files + files_col = db_form["fs.files"] + cursor = files_col.find(q).sort("uploadDate", -1).skip(skip).limit(limit) + items = [] + async for f in cursor: + items.append({ + "id": str(f["_id"]), + "filename": f.get("filename"), + "length": f.get("length"), + "uploadDate": f.get("uploadDate").isoformat() if f.get("uploadDate") else None, + "contentType": f.get("contentType"), + "metadata": f.get("metadata"), + }) + total = await files_col.count_documents(q) + return {"total": total, "skip": skip, "limit": limit, "items": items} + + +# news_col = mongo_client["back"]["posts"] +# async for i in news_col.find({}).limit(20): +# print(i) \ No newline at end of file