Files
life-echo/api/app/agents/memoir/classification_agent.py

78 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ClassificationAgent将内容分类到 8 个章节类别,或判定无价值返回 None。
对应现有逻辑_classify_chapter_category
"""
from __future__ import annotations
from typing import Any, Optional
from app.core.logging import get_logger
from app.agents.memoir.prompts import (
CHAPTER_CATEGORIES,
get_chapter_classification_prompt,
)
logger = get_logger(__name__)
# 5-stage 关键词(用于 LLM 失败时的兜底)
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
# 5-stage → 默认 8-category 映射LLM 分类失败时的兜底)
_STAGE_TO_DEFAULT_CATEGORY = {
"childhood": "childhood",
"education": "education",
"career": "career_early",
"family": "family",
"belief": "beliefs",
}
def _detect_stage(text: str, fallback_stage: str) -> str:
"""根据关键词检测消息所属的 5-stage 阶段"""
message = (text or "").lower()
for stage, keywords in STAGE_KEYWORDS.items():
if any(word in message for word in keywords):
return stage
return fallback_stage
class ClassificationAgent:
"""将内容分类到 8 个章节类别之一,或判定无价值返回 None"""
def classify(
self,
text: str,
fallback_stage: str,
llm: Any,
) -> Optional[str]:
"""
分类到 8 个章节类别之一。
若 LLM 判定内容无实质回忆录价值,返回 None。
llm 需支持 .invoke(prompt) 同步调用。
"""
if llm:
try:
prompt = get_chapter_classification_prompt(text)
response = llm.invoke(prompt)
category = (response.content or "").strip().lower()
if category == "none":
logger.info("LLM 判定内容无回忆录价值,跳过: %s...", (text or "")[:80])
return None
if category in CHAPTER_CATEGORIES:
return category
except Exception as e:
logger.warning("ClassificationAgent LLM 章节分类失败: %s", e)
stage = _detect_stage(text, fallback_stage)
return _STAGE_TO_DEFAULT_CATEGORY.get(
stage,
_STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood"),
)