Добавить main.py
This commit is contained in:
881
main.py
Normal file
881
main.py
Normal file
@@ -0,0 +1,881 @@
|
|||||||
|
# FastAPI
|
||||||
|
|
||||||
|
## App
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Depends, HTTPException, Body, Security, Form, Request, Response
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
|
||||||
|
from bson import ObjectId
|
||||||
|
import redis.asyncio as redis
|
||||||
|
import jwt
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Dict, Any, List, Set
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
|
||||||
|
### Configuration

app = FastAPI(title="FastAPI Auth")
templates = Jinja2Templates(directory="templates")

app.mount("/static", StaticFiles(directory="static"), name="static")

# PROD defaults to True; only the literal env value 'False' disables it.
PROD = os.getenv("PROD") != 'False'
SECRET_KEY = os.getenv("SECRET_KEY")
MONGO_INITDB_ROOT_PASSWORD = os.getenv("MONGO_INITDB_ROOT_PASSWORD")
MONGODB_HOST = os.getenv("MONGODB_SERVICE_HOST")
REDIS_HOST = os.getenv("REDIS_SERVICE_HOST")

# JWT settings.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
REFRESH_TOKEN_EXPIRE_DAYS = 170

# Non-production talks to NodePort-style ports (30000 / 30001).
mongo_client = AsyncIOMotorClient(f"mongodb://admin:{MONGO_INITDB_ROOT_PASSWORD}@{MONGODB_HOST}:{27017 if PROD else 30000}/")
db_form = mongo_client["form"]
form_col = db_form['form']  # fix: was defined twice (second binding was the same collection)
gridfs = AsyncIOMotorGridFSBucket(db_form)
db = mongo_client["auth"]
users_collection = db["users"]
news_col = mongo_client["back"]['news']

redis_client = redis.Redis(host=REDIS_HOST, port=6379 if PROD else 30001, db=0, decode_responses=True)

pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")


### ROLE_SCOPES

# Scopes granted per role; per-user extra_scopes are merged on top of these.
ROLE_SCOPES: Dict[str, List[str]] = {
    "user": ["read:items", "create:items"],
    "premium": ["read:items", "create:items", "update:items"],
    "moderator": ["read:items", "moderate", "update:items", "delete:items"],
    "admin": ["read:items", "create:items", "update:items", "delete:items", "moderate", "admin", "api:debug"],
}
# Roles that may NOT be self-assigned at registration time.
BLOCK_ROLES = ['admin']
|
||||||
|
|
||||||
|
|
||||||
|
### Pydantic модели
|
||||||
|
|
||||||
|
class UserRegister(BaseModel):
    """Registration payload: password + client_id plus at least one identifier."""

    username: str | None = None
    phone: str | None = None
    email: str | None = None
    password: str
    client_id: str
    # Requested roles; privileged ones are filtered server-side (BLOCK_ROLES).
    roles: List[str] = ["user"]
    extra_scopes: List[str] = []

    @model_validator(mode='after')
    def at_least_one_identifier(self):
        """Reject payloads that carry no username, phone, or email."""
        if self.username or self.phone or self.email:
            return self
        raise ValueError("At least one of username, phone or email is required")
|
||||||
|
|
||||||
|
|
||||||
|
class UserLoginForm(OAuth2PasswordRequestForm):
    """Alias over FastAPI's standard OAuth2 password form (no extra fields yet)."""
    pass
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
    """JSON body for /login."""
    identifier: str  # username, phone, or email
    password: str
    client_id: str
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
    """Token pair returned by /refresh."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshRequest(BaseModel):
    """Optional JSON body for /refresh; cookie/header take precedence."""
    refresh_token: str | None = None
    client_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
### Хелперы
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
from typing import List
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
async def fetch_latest_from_mongo(limit: int = 20) -> List[dict]:
    """Return the *limit* newest documents from ``news_col`` (newest first).

    ``published_at`` values are converted to ISO-8601 strings so the result
    can be serialized; ``_id`` is left as-is (ObjectId). Dead commented-out
    id-conversion code was removed.
    """
    cursor = news_col.find({}).sort("_id", -1).limit(limit)
    docs: List[dict] = []
    async for doc in cursor:
        published = doc.get("published_at")
        # datetime (or anything datetime-like) -> ISO string; single lookup.
        if published is not None and hasattr(published, "isoformat"):
            doc["published_at"] = published.isoformat()
        docs.append(doc)
    return docs
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_for_cache(items: List[dict]) -> bytes:
    """Serialize news documents to bytes for the Redis cache.

    ``_id`` ObjectId values are stringified because orjson cannot encode
    them natively; input documents are not mutated.
    """
    def _stringify_id(doc: dict) -> dict:
        out = dict(doc)
        if isinstance(out.get("_id"), ObjectId):
            out["_id"] = str(out["_id"])
        return out

    return orjson.dumps({"items": [_stringify_id(doc) for doc in items]})
|
||||||
|
|
||||||
|
def deserialize_from_cache_to_objectids(payload: bytes) -> List[dict]:
    """Decode a cached payload, converting string ``_id`` back to ObjectId."""
    items = orjson.loads(payload).get("items", [])
    for entry in items:
        raw_id = entry.get("_id")
        if not isinstance(raw_id, str):
            continue
        try:
            entry["_id"] = ObjectId(raw_id)
        except Exception:
            # Not a valid ObjectId string: keep the raw value untouched.
            pass
    return items
|
||||||
|
|
||||||
|
async def get_latest_news() -> List[dict]:
    """Return the 20 newest news docs, via a 5-second Redis cache."""
    cached = await redis_client.get('latest_news')
    if cached:
        try:
            # Redis may hand back str or bytes; orjson.loads accepts both.
            if not isinstance(cached, (bytes, str)):
                cached = cached.decode()
            return deserialize_from_cache_to_objectids(cached)
        except Exception:
            # Corrupt cache entry: fall through to Mongo and refresh it.
            pass

    items = await fetch_latest_from_mongo(20)
    # serialize_for_cache returns bytes; redis .set accepts bytes.
    await redis_client.set('latest_news', serialize_for_cache(items), ex=5)
    return items
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
def normalize_phone(phone: Optional[str]) -> Optional[str]:
    """Normalize a Russian phone number to ``+7XXXXXXXXXX`` form.

    Returns None when the input is missing, shorter than 6 characters, or
    cannot be recognized (fix: the old code fell through and returned None
    implicitly — now explicit).
    """
    if not phone or len(phone) < 6:
        return None

    # Keep digits and '+' only (drops spaces, dashes, parentheses).
    cleaned = ''.join(ch for ch in phone.strip() if ch.isdigit() or ch == '+')

    if cleaned.startswith('+7'):
        return cleaned

    if cleaned.startswith('7'):
        return '+' + cleaned

    # Domestic format 8XXXXXXXXXX -> +7XXXXXXXXXX.
    if cleaned.startswith('8') and len(cleaned) == 11:
        return '+7' + cleaned[1:]

    # Bare local number (10+ digits): assume Russian country code.
    if len(cleaned) >= 10:
        return '+7' + cleaned

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_bearer_token(request: Request):
    """Return the token from an ``Authorization: Bearer …`` header, else None."""
    header = request.headers.get("Authorization")
    if not header or not header.startswith("Bearer "):
        return None
    # maxsplit=1 keeps tokens containing spaces intact.
    return header.split(" ", 1)[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def check_bruteforce(username: str):
    """Raise HTTP 429 when this identifier has 8+ recent failed logins."""
    key = f"login_attempts:{username.lower()}"
    raw = await redis_client.get(key)
    if raw is None:
        # No recorded failures: nothing to enforce.
        return

    if int(raw) >= 8:
        ttl = await redis_client.ttl(key)
        # Round the remaining lockout up to whole minutes; default 15.
        minutes = (ttl + 59) // 60 if ttl > 0 else 15
        raise HTTPException(
            status_code=429,
            detail=f"Too many failed login attempts. Try again in ~{minutes} minutes."
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def increment_failed_attempt(username: str):
    """Bump the failed-login counter and (re)arm its 20-minute expiry.

    Fixes two async-redis bugs: ``pipeline()`` is a plain factory and must
    NOT be awaited, while ``execute()`` IS a coroutine and must be awaited —
    previously the INCR/EXPIRE never actually reached Redis.
    """
    key = f"login_attempts:{username.lower()}"
    pipe = redis_client.pipeline()
    pipe.incr(key)
    pipe.expire(key, 20 * 60)
    await pipe.execute()
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_failed_attempts(username: str):
    """Clear the failed-login counter (called after a successful login)."""
    await redis_client.delete(f"login_attempts:{username.lower()}")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against an argon2 *hashed* value (constant-time)."""
    return pwd_context.verify(plain, hashed)
|
||||||
|
|
||||||
|
|
||||||
|
def get_password_hash(password: str) -> str:
    """Hash *password* with the configured scheme (argon2)."""
    return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def create_jwt_token(data: Dict, expires_delta: timedelta) -> str:
    """Sign *data* as an HS256 JWT expiring after *expires_delta* (UTC)."""
    claims = {**data, "exp": datetime.now(timezone.utc) + expires_delta}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_user_scopes(roles: List[str], extra_scopes: List[str]) -> List[str]:
    """Union of role-derived scopes and per-user extras, deduplicated.

    Unknown roles are ignored. Result order is unspecified (set-backed).
    """
    merged: Set[str] = set(extra_scopes)
    for role in roles:
        merged.update(ROLE_SCOPES.get(role, ()))
    return list(merged)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_token(token: str) -> str:
    """Hex SHA-256 digest of *token* (refresh tokens are stored hashed)."""
    digest = hashlib.sha256(token.encode())
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
### CSRF защита
|
||||||
|
|
||||||
|
def generate_csrf_token() -> str:
    # 32 random bytes, URL-safe base64 (~43 chars) — cryptographically strong.
    return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
|
def set_csrf_cookie(response: Response, csrf_token: str):
    """Attach the CSRF token as a 24h cookie on *response*."""
    response.set_cookie(
        key="csrf_token",
        value=csrf_token,
        # NOTE(review): httponly=PROD means the cookie IS readable by JS in
        # dev but not in prod — confirm which double-submit flow is intended.
        httponly=PROD,
        secure=PROD,
        samesite="lax",
        max_age=3600 * 24
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_csrf(request: Request, user_id_str: str, response: Response):
    """Validate the double-submit CSRF token for mutating requests.

    Safe methods pass through. Otherwise the ``csrf_token`` cookie must
    match the server-side value stored in Redis under ``csrf:<user_id>``;
    on success the token is rotated (new Redis value + new cookie).
    Raises HTTP 403 on a missing or mismatched token.
    """
    # Safe/idempotent methods never need CSRF protection.
    if request.method in ("GET", "HEAD", "OPTIONS", "TRACE"):
        return

    cookie_csrf = request.cookies.get("csrf_token")

    if not cookie_csrf:
        raise HTTPException(status_code=403, detail="CSRF token mismatch or missing")

    stored_csrf_token = await redis_client.get(f"csrf:{user_id_str}")

    if stored_csrf_token is None or stored_csrf_token != cookie_csrf:
        raise HTTPException(status_code=403, detail="CSRF token mismatch or missing")

    # Rotate the token on every successful check (single-use style).
    csrf_token = generate_csrf_token()
    await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token)
    set_csrf_cookie(response, csrf_token)
|
||||||
|
|
||||||
|
|
||||||
|
### Извлечение access token из cookie или Authorization
|
||||||
|
|
||||||
|
def get_access_token(request: Request) -> str | None:
    """Access token from the ``access_token`` cookie, else the Bearer header.

    Returns None when neither source is present.
    """
    token = request.cookies.get("access_token")
    if token:
        return token
    auth = request.headers.get("Authorization")
    if auth and auth.startswith("Bearer "):
        # Fix: split once (maxsplit=1) so a token containing spaces is not
        # truncated — consistent with get_bearer_token.
        return auth.split(" ", 1)[1]
    return None
|
||||||
|
|
||||||
|
|
||||||
|
### Получение текущего пользователя
|
||||||
|
|
||||||
|
async def get_current_user(
    request: Request,
    security_scopes: SecurityScopes
) -> Dict[str, Any]:
    """Resolve and authorize the current user from the access token.

    The token must exist in Redis (``access:<token>`` -> user id), decode
    as a valid non-expired JWT whose ``sub`` matches that id, and carry
    every scope the endpoint declared via ``Security(...)``.

    Returns the user document (without ``_id``/``hashed_password``),
    enriched with ``id`` and the full recomputed ``scopes`` list.
    Raises HTTP 401 on auth failures, HTTP 403 on missing scopes.
    """
    token = get_access_token(request)
    if not token:
        raise HTTPException(status_code=401, detail="Not authenticated")

    # Redis acts as the token allowlist: a deleted key means revoked/expired.
    key = f"access:{token}"
    user_id_str = await redis_client.get(key)
    if not user_id_str:
        raise HTTPException(status_code=401, detail="Invalid or expired access token")

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        # Defense in depth: JWT subject must match the Redis mapping.
        if payload.get("sub") != user_id_str:
            raise HTTPException(status_code=401, detail="Token mismatch")

        # Projection drops the raw _id and the password hash.
        user = await users_collection.find_one({"_id": ObjectId(user_id_str)}, {"_id": 0, "hashed_password": 0})
        if not user:
            raise HTTPException(status_code=401, detail="User not found")

        # Scope check uses the scopes baked into the token at issue time.
        token_scopes: List[str] = payload.get("scopes", [])
        for scope in security_scopes.scopes:
            if scope not in token_scopes:
                raise HTTPException(status_code=403, detail="Not enough permissions")

        # Recompute current scopes from the (possibly updated) user document.
        scopes = get_all_user_scopes(user.get("roles", []), user.get("extra_scopes", []))
        user['scopes'] = scopes
        user["id"] = user_id_str
        return user

    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Access token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
|
### Эндпоинты
|
||||||
|
|
||||||
|
@app.post("/register", response_model=dict)
async def register(
    request: Request,
    user: UserRegister,
    response: Response,
):
    """Create a user, open a session (JWT pair + CSRF), and set auth cookies.

    Enforces uniqueness of username (as-is), phone (normalized) and email
    (lowercased), strips privileged roles, then issues and persists the
    access/refresh tokens exactly like /login does.
    """
    # Uniqueness checks against existing accounts.
    if user.username and await users_collection.find_one({"username": user.username}):
        raise HTTPException(400, "Имя пользователя занято")
    if normalize_phone(user.phone) and await users_collection.find_one({"phone": normalize_phone(user.phone)}):
        raise HTTPException(400, "Телефон уже зарегестрирован")
    if user.email and await users_collection.find_one({"email": user.email.lower()}):
        raise HTTPException(400, "Email уже зарегестрирован")

    hashed_password = get_password_hash(user.password)

    # Never allow self-registration into blocked roles (e.g. 'admin').
    allowed_roles = [r for r in user.roles if r not in BLOCK_ROLES]

    doc = {
        "username": user.username,
        "phone": normalize_phone(user.phone) if user.phone else None,
        "email": user.email.lower() if user.email else None,
        "hashed_password": hashed_password,
        "roles": allowed_roles,
    }

    result = await users_collection.insert_one(doc)
    user_id_str = str(result.inserted_id)

    # Requested extra_scopes are intentionally ignored at registration.
    scopes = get_all_user_scopes(allowed_roles, [])

    client_id = user.client_id
    client_ip = request.client.host

    access_exp_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_exp_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    access_token = create_jwt_token(
        {"sub": user_id_str, "scopes": scopes, "client_id": client_id, "ip": client_ip},
        access_exp_delta
    )
    refresh_token = create_jwt_token(
        {"sub": user_id_str, "type": "refresh", "client_id": client_id, "ip": client_ip},
        refresh_exp_delta
    )

    # Access tokens are stored verbatim; refresh tokens only as SHA-256.
    await redis_client.setex(f"access:{access_token}", int(access_exp_delta.total_seconds()), user_id_str)
    await redis_client.setex(f"refresh:{hash_token(refresh_token)}", int(refresh_exp_delta.total_seconds()), user_id_str)

    # Session index used by /logout to revoke everything at once.
    session_key = f"user_sessions:{user_id_str}"
    await redis_client.sadd(session_key, f"access:{access_token}")
    await redis_client.sadd(session_key, f"refresh:{hash_token(refresh_token)}")

    csrf_token = generate_csrf_token()
    await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token)

    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        secure=PROD,
        samesite="lax",
        max_age=int(access_exp_delta.total_seconds())
    )
    response.set_cookie(
        key="refresh_token",
        value=refresh_token,
        httponly=True,
        secure=PROD,
        samesite="lax",
        max_age=int(refresh_exp_delta.total_seconds())
    )
    set_csrf_cookie(response, csrf_token)

    # Tokens are also returned in the body for non-browser clients.
    return {
        "ok": True,
        "access_token": access_token,
        "refresh_token": refresh_token,
        "user_id": user_id_str,
        "token_type": "bearer"
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/login")
async def login(
    request: Request,
    response: Response,
    login_data: LoginRequest = Body(...),
):
    """Authenticate by username/phone/email + password; issue a session.

    Applies brute-force throttling per identifier, issues a JWT pair and a
    CSRF token, persists them in Redis, and sets the auth cookies.
    """
    identifier = login_data.identifier.strip()
    # Raises 429 when too many recent failures exist for this identifier.
    await check_bruteforce(identifier)

    # A single query matches whichever identifier kind the client sent.
    user = await users_collection.find_one({
        "$or": [
            {"username": identifier},
            {"phone": normalize_phone(identifier)},
            {"email": identifier.lower()},
        ]
    })

    if not user or not verify_password(login_data.password, user["hashed_password"]):
        await increment_failed_attempt(identifier)
        raise HTTPException(401, "Incorrect credentials")

    await reset_failed_attempts(identifier)

    user_id_str = str(user["_id"])
    scopes = get_all_user_scopes(user.get("roles", []), user.get("extra_scopes", []))

    client_ip = request.client.host

    access_exp_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_exp_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    access_token = create_jwt_token(
        {"sub": user_id_str, "scopes": scopes, "client_id": login_data.client_id, "ip": client_ip},
        access_exp_delta
    )
    refresh_token = create_jwt_token(
        {"sub": user_id_str, "type": "refresh", "client_id": login_data.client_id, "ip": client_ip},
        refresh_exp_delta
    )

    # Access tokens are stored verbatim; refresh tokens only as SHA-256.
    await redis_client.setex(f"access:{access_token}", int(access_exp_delta.total_seconds()), user_id_str)
    await redis_client.setex(f"refresh:{hash_token(refresh_token)}", int(refresh_exp_delta.total_seconds()), user_id_str)

    # Session index used by /logout to revoke everything at once.
    session_key = f"user_sessions:{user_id_str}"
    await redis_client.sadd(session_key, f"access:{access_token}")
    await redis_client.sadd(session_key, f"refresh:{hash_token(refresh_token)}")

    csrf_token = generate_csrf_token()
    await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token)

    response.set_cookie(
        key="access_token",
        value=access_token,
        httponly=True,
        secure=PROD,
        samesite="lax",
        max_age=int(access_exp_delta.total_seconds())
    )
    response.set_cookie(
        key="refresh_token",
        value=refresh_token,
        httponly=True,
        # Fix: was hard-coded secure=False, inconsistent with the access
        # cookie above and with /register — sent the refresh token over
        # plain HTTP even in production.
        secure=PROD,
        samesite="lax",
        max_age=int(refresh_exp_delta.total_seconds())
    )
    set_csrf_cookie(response, csrf_token)

    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "user_id": user_id_str,
        "token_type": "bearer"
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/refresh", response_model=TokenResponse)
async def refresh_token(
    request: Request,
    response: Response,
    req: RefreshRequest
):
    """Rotate the token pair: consume a valid refresh token, issue new ones.

    The refresh token is taken from (in order) the cookie, the Bearer
    header, or the JSON body. It is single-use: the consumed token is
    revoked before new tokens are issued.
    """
    refresh_token = (
        request.cookies.get("refresh_token") or
        get_bearer_token(request) or
        (req.refresh_token if req else None)
    )

    if not refresh_token:
        raise HTTPException(401, "Refresh token required")

    client_id = None
    if req:
        client_id = req.client_id

    # Redis stores a SHA-256 of the refresh token, not the token itself.
    refresh_hash = hash_token(refresh_token)
    user_id_str = await redis_client.get(f"refresh:{refresh_hash}")

    if not user_id_str:
        raise HTTPException(status_code=401, detail="Invalid or expired refresh token")

    try:
        payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
        if payload.get("type") != "refresh":
            raise HTTPException(status_code=401, detail="Not a refresh token")
        if payload.get("sub") != user_id_str:
            raise HTTPException(status_code=401, detail="Token mismatch")
    except jwt.InvalidTokenError as e:
        raise HTTPException(status_code=401, detail=f"Invalid refresh token: {str(e)}")

    # One-shot semantics: revoke the refresh token that was just used.
    # (Fix: debug print() calls removed from this handler.)
    await redis_client.delete(f"refresh:{refresh_hash}")
    user = await users_collection.find_one({"_id": ObjectId(user_id_str)})
    if not user:
        # Fix: a deleted user previously caused AttributeError -> HTTP 500.
        raise HTTPException(status_code=401, detail="User not found")

    scopes = get_all_user_scopes(user.get("roles", []), user.get("extra_scopes", []))

    client_ip = request.client.host

    access_exp = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_exp = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    new_access = create_jwt_token(
        {"sub": user_id_str, "scopes": scopes, "client_id": client_id, "ip": client_ip},
        access_exp
    )
    new_refresh = create_jwt_token(
        {"sub": user_id_str, "type": "refresh", "client_id": client_id, "ip": client_ip},
        refresh_exp
    )

    new_refresh_hash = hash_token(new_refresh)
    await redis_client.setex(f"access:{new_access}", int(access_exp.total_seconds()), user_id_str)
    await redis_client.setex(f"refresh:{new_refresh_hash}", int(refresh_exp.total_seconds()), user_id_str)

    # Track the new keys so /logout can revoke them all.
    session_key = f"user_sessions:{user_id_str}"
    await redis_client.sadd(session_key, f"access:{new_access}")
    await redis_client.sadd(session_key, f"refresh:{new_refresh_hash}")

    # Rotate the CSRF token alongside the JWTs.
    csrf_token = generate_csrf_token()
    await redis_client.setex(f"csrf:{user_id_str}", 3600 * 24, csrf_token)

    response.set_cookie(
        key="access_token",
        value=new_access,
        httponly=True,
        secure=PROD,  # fix: was hard-coded True, breaking non-HTTPS dev
        samesite="strict",  # NOTE(review): /login and /register use "lax"
        max_age=int(access_exp.total_seconds())
    )

    response.set_cookie(
        key="refresh_token",
        value=new_refresh,
        httponly=True,
        secure=PROD,  # fix: was hard-coded True
        samesite="strict",
        max_age=int(refresh_exp.total_seconds())
    )

    set_csrf_cookie(response, csrf_token)

    return {
        "access_token": new_access,
        "token_type": "bearer",
        "refresh_token": new_refresh
    }
|
||||||
|
|
||||||
|
@app.post("/logout")
async def logout(
    request: Request,
    response: Response,
    current_user: Dict = Depends(get_current_user)
):
    """Revoke every session of the current user and clear auth cookies."""
    user_id_str = current_user["id"]

    # Fix: validate_csrf takes (request, user_id_str, response); the old call
    # passed only `request`, so every logout failed with a TypeError.
    await validate_csrf(request, user_id_str, response)

    # Delete every access/refresh key recorded for this user.
    session_key = f"user_sessions:{user_id_str}"
    tokens = await redis_client.smembers(session_key)
    for t in tokens:
        await redis_client.delete(t)
    await redis_client.delete(session_key)

    # Drop the server-side CSRF token (validate_csrf re-created it above).
    await redis_client.delete(f"csrf:{user_id_str}")

    # Expire the client-side cookies.
    response.delete_cookie("access_token")
    response.delete_cookie("refresh_token")
    response.delete_cookie("csrf_token")

    return {"message": "Logged out from all sessions"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/me")
async def get_me(
    current_user: Dict = Security(get_current_user, scopes=["read:items"])
):
    # Requires the 'read:items' scope; returns the user document enriched by
    # get_current_user ('id', 'scopes'; the password hash is never included).
    return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/items")
async def create_item(
    request: Request,
    response: Response,
    current_user: Dict = Security(get_current_user, scopes=["create:items"])
):
    # The CSRF check also rotates the token and re-sets the cookie on `response`.
    await validate_csrf(request, current_user["id"], response)
    # NOTE(review): 'username' may be None for phone/email-only accounts.
    return {"message": f"Item created by {current_user['username']}"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/healthz")
async def healthz():
    # Liveness probe: empty 200 response, no dependency checks performed.
    return {}
|
||||||
|
|
||||||
|
@app.get("/")
async def main(request: Request):
    # NOTE(review): debug endpoint — echoes ALL request cookies and headers
    # (including auth tokens) back to the caller; must not ship to production.
    cookies = request.cookies
    headers = request.headers
    return {"cookies": cookies, "headers": headers}
|
||||||
|
|
||||||
|
@app.get("/form", response_class=HTMLResponse)
async def form_(request: Request):
    """Render the submission form page (templates/form.html)."""
    context = {"id": 1}
    return templates.TemplateResponse(request=request, name="form.html", context=context)
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
|
@app.get("/test", response_class=HTMLResponse)
async def read_item(request: Request):
    # Demo page: real cached news plus hard-coded placeholder posts for layout.
    news = await get_latest_news()
    return templates.TemplateResponse(
        request=request, name="index.html", context={"news": news, "posts": [{'_id':ObjectId(), 'title':'Title', 'description':'Description', 'user': {'name':" F I O", 'grade': 'No grade'}, 'media': [{'url':'/static/logo.jpeg'}]},{'_id':ObjectId(), 'media': [{'url':'/static/logo.jpeg'}]}]}
    )
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class FormOut(BaseModel):
    """Response schema for a stored contest-form document."""
    # NOTE(review): in pydantic v2, Optional[...] WITHOUT a default is still
    # a required field — confirm these are meant to be required-but-nullable.
    id: str
    title: str
    topic: Optional[str]
    short_desc: Optional[str]
    rules_file_id: Optional[str]
    poster_file_id: Optional[str]
    min_paintings: Optional[int]
    max_paintings: Optional[int]
    selected_count: Optional[int]
|
||||||
|
|
||||||
|
@app.post("/api", response_model=FormOut)
async def receive_form(
    poster: Optional[UploadFile] = File(None),
    title: str = Form(...),
    topic: Optional[str] = Form(None),
    short_desc: Optional[str] = Form(None),
    rules: Optional[UploadFile] = File(None),
    min_paintings: Optional[int] = Form(None),
    max_paintings: Optional[int] = Form(None),
    selected_count: Optional[int] = Form(None),
):
    """Accept the contest form: store uploads in GridFS, insert a document.

    Returns the stored document (HTTP 201) with GridFS file ids as strings.
    """
    async def save_to_gridfs(upload: Optional[UploadFile]) -> Optional[str]:
        # Store one uploaded file in GridFS and return its id as a string.
        if upload is None:
            return None
        filename = upload.filename or str(uuid.uuid4())
        contents = await upload.read()
        file_id = await gridfs.upload_from_stream(filename, contents)
        return str(file_id)

    try:
        poster_id = await save_to_gridfs(poster)
        rules_id = await save_to_gridfs(rules)

        doc = {
            "title": title,
            "topic": topic,
            "short_desc": short_desc,
            "poster_file_id": poster_id,
            "rules_file_id": rules_id,
            # FastAPI already coerced these to int (or None) per the
            # annotations; the old redundant int() casts were removed.
            "min_paintings": min_paintings,
            "max_paintings": max_paintings,
            "selected_count": selected_count,
        }

        res = await form_col.insert_one(doc)
        doc_out = {"id": str(res.inserted_id), **doc}
        # Fix: .model_dump() replaces the pydantic-v1 .dict(), which is
        # deprecated under the pydantic v2 this file uses (model_validator).
        # Debug print(doc_out) removed.
        return JSONResponse(status_code=201, content=FormOut(**doc_out).model_dump())
    except Exception as e:
        # NOTE(review): echoes internal error text to the client — prefer
        # logging the exception and returning a generic message.
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from bson import ObjectId
|
||||||
|
import mimetypes
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
# helper: stream file from GridFS by id (str)
|
||||||
|
async def stream_gridfs_file(file_id: str):
    """Open a GridFS file for streaming.

    Returns ``(filename, content_type, iterator_factory)`` where calling
    ``iterator_factory()`` yields the file in 64 KiB chunks, or None when
    *file_id* is missing or not a valid ObjectId.

    NOTE(review): a valid-but-unknown id makes open_download_stream raise
    (unhandled -> 500); and the stream is only closed when the iterator is
    fully consumed — confirm behavior on client disconnect.
    """
    if file_id is None:
        return None
    try:
        oid = ObjectId(file_id)
    except Exception:
        # Malformed id: treat the same as "no file".
        return None
    # Open the GridFS download stream.
    stream = await gridfs.open_download_stream(oid)
    # File metadata with safe fallbacks.
    filename = stream.filename or "file"
    content_type = stream.content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
    # Async generator reading fixed-size chunks; closes the stream at EOF.
    async def file_iterator(chunk_size: int = 1024 * 64):
        while True:
            chunk = await stream.read(chunk_size)
            if not chunk:
                break
            yield chunk
        await stream.close()
    return filename, content_type, file_iterator
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from bson import ObjectId
|
||||||
|
import mimetypes
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
def content_disposition_header(filename: str, inline: bool = True) -> str:
    """Build a Content-Disposition header value for *filename*.

    ASCII (latin-1 encodable) names use the plain ``filename="..."`` form;
    other names fall back to the RFC 5987 ``filename*=UTF-8''...`` form.

    Fix: the ASCII branch previously returned the literal string
    ``filename="(unknown)"`` instead of interpolating the real filename.
    """
    dispo_type = "inline" if inline else "attachment"
    try:
        # Probe: header values must be latin-1 encodable for the plain form.
        filename.encode("latin-1")
        return f'{dispo_type}; filename="{filename}"'
    except UnicodeEncodeError:
        # RFC 5987: filename* with UTF-8''percent-encoded name.
        filename_quoted = quote(filename, safe='')
        return f"{dispo_type}; filename*=UTF-8''{filename_quoted}"
|
||||||
|
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
@app.get("/files/{file_id}")
async def get_file(file_id: str):
    """Stream a GridFS file inline; the literal id 'random' picks one at random."""
    # Special-case: serve a uniformly random stored file.
    if file_id == "random":
        files_col = db_form["fs.files"]
        total = await files_col.count_documents({})
        if total == 0:
            raise HTTPException(status_code=404, detail="No files available")
        # Pick a random offset into the (unordered) collection scan.
        idx = randint(0, max(0, total - 1))
        cursor = files_col.find().skip(idx).limit(1)
        doc = await cursor.to_list(length=1)
        if not doc:
            raise HTTPException(status_code=404, detail="No file found")
        file_id = str(doc[0]["_id"])

    res = await stream_gridfs_file(file_id)
    if res is None:
        raise HTTPException(status_code=404, detail="File not found")
    filename, content_type, iterator = res
    headers = {"Content-Disposition": content_disposition_header(filename, inline=True)}
    return StreamingResponse(iterator(), media_type=content_type, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/documents/{doc_id}")
async def get_document(doc_id: str):
    """Return form-document metadata plus download URLs for its files."""
    try:
        oid = ObjectId(doc_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid document id")

    doc = await form_col.find_one({"_id": oid})
    if not doc:
        raise HTTPException(status_code=404, detail="Document not found")

    def make_url(fid):
        # File ids map onto the /files/{id} streaming endpoint.
        return f"/files/{fid}" if fid else None

    return {
        "id": str(doc["_id"]),
        "title": doc.get("title"),
        "topic": doc.get("topic"),
        "short_desc": doc.get("short_desc"),
        "poster_url": make_url(doc.get("poster_file_id")),
        "rules_url": make_url(doc.get("rules_file_id")),
        "min_paintings": doc.get("min_paintings"),
        "max_paintings": doc.get("max_paintings"),
        "selected_count": doc.get("selected_count"),
        "task_id": doc.get("task_id"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): this is an exact, byte-for-byte duplicate of the
# /documents/{doc_id} handler defined earlier in the file — the same route is
# registered twice and this later definition shadows the module-level name.
# One of the two copies should be deleted.
@app.get("/documents/{doc_id}")
async def get_document(doc_id: str):
    """Return form-document metadata plus download URLs for its files."""
    try:
        oid = ObjectId(doc_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid document id")
    doc = await form_col.find_one({"_id": oid})
    if not doc:
        raise HTTPException(status_code=404, detail="Document not found")
    # construct download/view urls (adjust host/path if needed)
    def make_url(fid):
        return f"/files/{fid}" if fid else None
    return {
        "id": str(doc["_id"]),
        "title": doc.get("title"),
        "topic": doc.get("topic"),
        "short_desc": doc.get("short_desc"),
        "poster_url": make_url(doc.get("poster_file_id")),
        "rules_url": make_url(doc.get("rules_file_id")),
        "min_paintings": doc.get("min_paintings"),
        "max_paintings": doc.get("max_paintings"),
        "selected_count": doc.get("selected_count"),
        "task_id": doc.get("task_id"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
from fastapi import Query
|
||||||
|
|
||||||
|
@app.get("/files")
async def list_gridfs_files(
    filename: Optional[str] = Query(None, description="Filter by filename (substring)"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
):
    """Paginated listing of GridFS file metadata, newest upload first."""
    query: dict = {}
    if filename:
        # Case-insensitive substring match on the stored filename.
        query["filename"] = {"$regex": filename, "$options": "i"}

    # GridFS keeps per-file metadata in <db>.fs.files.
    files_col = db_form["fs.files"]
    cursor = files_col.find(query).sort("uploadDate", -1).skip(skip).limit(limit)

    items = []
    async for record in cursor:
        upload_date = record.get("uploadDate")
        items.append({
            "id": str(record["_id"]),
            "filename": record.get("filename"),
            "length": record.get("length"),
            "uploadDate": upload_date.isoformat() if upload_date else None,
            "contentType": record.get("contentType"),
            "metadata": record.get("metadata"),
        })

    total = await files_col.count_documents(query)
    return {"total": total, "skip": skip, "limit": limit, "items": items}
|
||||||
|
|
||||||
|
|
||||||
|
# news_col = mongo_client["back"]["posts"]
|
||||||
|
# async for i in news_col.find({}).limit(20):
|
||||||
|
# print(i)
|
||||||
Reference in New Issue
Block a user