50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
"""用户资料收集:缺失字段检测、提取与应用"""
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.features.user.models import User
|
|
|
|
|
|
def get_missing_profile_fields(user: User) -> list:
    """Return the profile fields that ``user`` has not filled in yet.

    Thin adapter: reads ``birth_year``, ``birth_place``, ``grew_up_place`` and
    ``occupation`` from ``user`` and delegates the actual check to
    ``app.agents.chat.prompts_profile.get_missing_profile_fields``, returning
    whatever that function returns.
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid an import cycle with the chat-agent package — TODO confirm.
    from app.agents.chat.prompts_profile import (
        get_missing_profile_fields as detect_missing,
    )

    profile = {
        "birth_year": user.birth_year,
        "birth_place": user.birth_place,
        "grew_up_place": user.grew_up_place,
        "occupation": user.occupation,
    }
    return detect_missing(**profile)
|
|
|
|
|
|
def get_filled_profile_fields(user: User) -> dict:
    """Collect the profile fields ``user`` has already filled in.

    Only fields with a truthy value are included. Keys are the raw field names
    and appear in the order ``birth_year``, ``birth_place``, ``grew_up_place``,
    ``occupation``. ``birth_year`` is converted to ``str``; the other values are
    returned exactly as stored. The result is meant for display (per the
    original note, Chinese-language display), but no localization is done here.
    """
    result = {}
    for name in ("birth_year", "birth_place", "grew_up_place", "occupation"):
        current = getattr(user, name)
        if not current:
            # Unset/empty fields are treated as "not filled".
            continue
        result[name] = str(current) if name == "birth_year" else current
    return result
|
|
|
|
|
|
async def apply_extracted_profile(user: User, extracted: dict, db: AsyncSession):
    """Copy newly extracted profile values onto ``user`` and persist them.

    Only fields the user has not filled yet (falsy current value) are written;
    existing values are never overwritten. Extracted values that are themselves
    empty (``None``, ``""``, ``0``) are ignored, matching how the rest of this
    module treats falsy fields as "missing".

    Args:
        user: User model instance, updated in place.
        extracted: Mapping of field name to extracted value. Only the keys
            ``birth_year``, ``birth_place``, ``grew_up_place`` and
            ``occupation`` are read; any other keys are ignored.
        db: Async session; ``commit()`` and ``refresh(user)`` are awaited only
            when at least one field was actually written.
    """
    changed = False
    for field in ("birth_year", "birth_place", "grew_up_place", "occupation"):
        value = extracted.get(field)
        # Skip empty extractions (they would not fill the gap and would only
        # trigger a pointless commit or store an empty value) and never
        # overwrite a field the user already has.
        if not value or getattr(user, field):
            continue
        setattr(user, field, value)
        changed = True

    if changed:
        await db.commit()
        await db.refresh(user)
|