Files
life-echo/api/scripts/backfill_segment_dialogue_lineage.py

87 lines
3.0 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Backfill segments.user_message_id / lineage_json (paired from existing conversation_messages).

Usage::

    cd api && uv run python scripts/backfill_segment_dialogue_lineage.py
    cd api && uv run python scripts/backfill_segment_dialogue_lineage.py --limit 500
"""
from __future__ import annotations
import argparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.features.auth import models as _auth_models # noqa: F401
from app.features.conversation import models as _conv_models # noqa: F401
from app.features.memory import models as _memory_models # noqa: F401
from app.features.memoir import models as _memoir_models # noqa: F401
from app.features.payment import models as _payment_models # noqa: F401
from app.features.story import models as _story_models # noqa: F401
from app.features.user import models as _user_models # noqa: F401
from app.core.db import SessionLocal
from app.features.conversation.lineage_schemas import DialogueLineage
from app.features.conversation.models import Conversation, ConversationMessage, Segment
def _first_message_for_segment(
    session: Session, *, segment_id: str, role: str
) -> ConversationMessage | None:
    """Return the earliest message of ``role`` attached to a segment, or None.

    "Earliest" means lowest ``created_at``, with ``id`` ascending as the
    tie-breaker so the choice is stable when timestamps collide.

    Args:
        session: Session used to run the query.
        segment_id: Value matched against ``ConversationMessage.segment_id``.
        role: Value matched against ``ConversationMessage.role``.

    Returns:
        The first matching message, or ``None`` when the segment has no
        message with that role.
    """
    matches_segment = ConversationMessage.segment_id == segment_id
    matches_role = ConversationMessage.role == role
    oldest_first = (
        ConversationMessage.created_at.asc(),
        ConversationMessage.id.asc(),
    )
    query = (
        select(ConversationMessage)
        .where(matches_segment, matches_role)
        .order_by(*oldest_first)
        .limit(1)
    )
    # limit(1) guarantees at most one row, so scalar_one_or_none() cannot
    # raise for multiple results here.
    result = session.execute(query)
    return result.scalar_one_or_none()
def main() -> None:
    """Backfill ``user_message_id`` and ``lineage_json`` on segments lacking lineage.

    Selects segments whose ``lineage_json`` is NULL and whose conversation has
    a NULL ``deleted_at``, oldest first. For each segment, the earliest
    ``human`` message attached to it becomes the user message, and the
    earliest ``ai`` message (if any) becomes the assistant message. Segments
    with no ``human`` message are skipped and left untouched.

    Command-line arguments:
        --limit N: scan at most N segments. ``0`` (the default) or any
            non-positive value means no limit.

    All updates are committed in one transaction at the end, after which a
    summary line ``updated_segments=<n> scanned=<m>`` is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--limit", type=int, default=0, help="最多处理 segment 条数0 不限制")
    args = parser.parse_args()

    stmt = (
        select(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(
            Segment.lineage_json.is_(None),
            Conversation.deleted_at.is_(None),
        )
        # Tie-break on id so that --limit picks a deterministic set of rows
        # when several segments share the same created_at.
        .order_by(Segment.created_at.asc(), Segment.id.asc())
    )
    if args.limit > 0:
        stmt = stmt.limit(args.limit)

    updated = 0
    # The context manager closes the session on every exit path; if an
    # exception escapes before commit(), the pending changes are discarded.
    with SessionLocal() as session:
        segments = session.execute(stmt).scalars().all()
        for seg in segments:
            segment_id = str(seg.id)
            human_msg = _first_message_for_segment(
                session, segment_id=segment_id, role="human"
            )
            if human_msg is None:
                # Nothing to pair against; leave this segment's lineage NULL.
                # NOTE(review): such segments are re-selected on every run and
                # count toward --limit — confirm that is acceptable.
                continue
            ai_msg = _first_message_for_segment(
                session, segment_id=segment_id, role="ai"
            )
            lineage = DialogueLineage.for_single_turn(
                conversation_id=str(seg.conversation_id),
                user_message_id=str(human_msg.id),
                assistant_message_id=str(ai_msg.id) if ai_msg is not None else None,
                segment_ids=[segment_id],
            )
            seg.user_message_id = str(human_msg.id)
            seg.lineage_json = lineage.model_dump(mode="json")
            updated += 1
        session.commit()
        print(f"updated_segments={updated} scanned={len(segments)}")
# Script entry point: run the backfill only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()