91 lines
2.7 KiB
Python
91 lines
2.7 KiB
Python
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.features.auth.models import RefreshToken, SmsVerificationCode
|
|
from app.features.user.models import User
|
|
|
|
|
|
async def get_user_by_phone(phone: str, db: AsyncSession) -> User | None:
    """Look up a single user by phone number.

    Args:
        phone: Value compared for equality against ``User.phone``.
        db: Async session the query is executed on.

    Returns:
        The matching ``User``, or ``None`` when no row matches.
        ``scalar_one_or_none`` raises if more than one row matches.
    """
    rows = await db.execute(select(User).where(User.phone == phone))
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_user_by_id(user_id: str, db: AsyncSession) -> User | None:
    """Fetch a user through ``AsyncSession.get`` using ``user_id`` as identity.

    Args:
        user_id: Identifier handed to ``db.get`` for the ``User`` entity.
        db: Async session used for the lookup.

    Returns:
        The ``User`` instance, or ``None`` if the session finds no such row.
    """
    found = await db.get(User, user_id)
    return found
|
|
|
|
|
|
async def get_user_by_email(email: str, db: AsyncSession) -> User | None:
    """Look up a single user by e-mail address.

    Args:
        email: Value compared for equality against ``User.email``.
        db: Async session the query is executed on.

    Returns:
        The matching ``User``, or ``None`` when no row matches.
        ``scalar_one_or_none`` raises if more than one row matches.
    """
    rows = await db.execute(select(User).where(User.email == email))
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_refresh_token_by_token(
    token_str: str, db: AsyncSession
) -> RefreshToken | None:
    """Find the refresh-token record whose ``token`` column equals ``token_str``.

    Args:
        token_str: Token value to match against ``RefreshToken.token``.
        db: Async session the query is executed on.

    Returns:
        The matching ``RefreshToken``, or ``None`` when no row matches.
        No revocation filter is applied here; callers see revoked rows too.
    """
    matches_token = RefreshToken.token == token_str
    rows = await db.execute(select(RefreshToken).where(matches_token))
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_active_tokens_for_user(
    user_id: str, db: AsyncSession
) -> list[RefreshToken]:
    """Return all refresh tokens of a user that are not flagged as revoked.

    "Active" in this query means only ``is_revoked`` is false; no expiry
    condition is part of the filter.

    Args:
        user_id: Value compared against ``RefreshToken.user_id``.
        db: Async session the query is executed on.

    Returns:
        A new list of matching ``RefreshToken`` rows (empty if none).
    """
    owned_by_user = RefreshToken.user_id == user_id
    # ``== False`` is intentional: it builds a SQL expression, not a Python bool.
    not_revoked = RefreshToken.is_revoked == False  # noqa: E712
    query = select(RefreshToken).where(owned_by_user, not_revoked)
    rows = await db.execute(query)
    tokens = rows.scalars().all()
    return list(tokens)
|
|
|
|
|
|
async def create_user(user: User, db: AsyncSession) -> None:
    """Register ``user`` with the session.

    Only ``db.add`` is called; this function performs no flush or commit,
    so persisting the row is left to the caller.
    """
    db.add(user)
|
|
|
|
|
|
async def create_refresh_token(token: RefreshToken, db: AsyncSession) -> None:
    """Register a refresh-token record with the session.

    Only ``db.add`` is called; this function performs no flush or commit,
    so persisting the row is left to the caller.
    """
    db.add(token)
|
|
|
|
|
|
# ── SMS verification code ─────────────────────────────────────
|
|
|
|
|
|
async def create_verification_code(
    record: SmsVerificationCode, db: AsyncSession
) -> None:
    """Register an SMS verification-code record with the session.

    Only ``db.add`` is called; this function performs no flush or commit,
    so persisting the row is left to the caller.
    """
    db.add(record)
|
|
|
|
|
|
async def get_recent_code_for_rate_limit(
    phone: str, db: AsyncSession
) -> SmsVerificationCode | None:
    """Return the newest verification-code record for ``phone``.

    Intended for rate-limit checks: the query ignores purpose and
    used/expired state and simply takes the row with the greatest
    ``created_at``.

    Args:
        phone: Value compared against ``SmsVerificationCode.phone``.
        db: Async session the query is executed on.

    Returns:
        The most recently created record, or ``None`` if the phone has none.
    """
    newest_first = SmsVerificationCode.created_at.desc()
    query = select(SmsVerificationCode).where(SmsVerificationCode.phone == phone)
    query = query.order_by(newest_first).limit(1)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_latest_unused_code(
    phone: str, purpose: str, db: AsyncSession
) -> SmsVerificationCode | None:
    """Return the newest code for phone+purpose that is neither used nor expired.

    Args:
        phone: Value compared against ``SmsVerificationCode.phone``.
        purpose: Value compared against ``SmsVerificationCode.purpose``.
        db: Async session the query is executed on.

    Returns:
        The most recently created record whose ``is_used`` and ``is_expired``
        flags are both false, or ``None`` if there is no such record.
    """
    criteria = (
        SmsVerificationCode.phone == phone,
        SmsVerificationCode.purpose == purpose,
        SmsVerificationCode.is_used.is_(False),
        SmsVerificationCode.is_expired.is_(False),
    )
    newest_first = SmsVerificationCode.created_at.desc()
    query = (
        select(SmsVerificationCode)
        .where(*criteria)
        .order_by(newest_first)
        .limit(1)
    )
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|