87 lines
3.0 KiB
Python
87 lines
3.0 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
"""Backfill segments.user_message_id / lineage_json (paired from existing conversation_messages).

Usage::

    cd api && uv run python scripts/backfill_segment_dialogue_lineage.py
    cd api && uv run python scripts/backfill_segment_dialogue_lineage.py --limit 500
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.features.auth import models as _auth_models # noqa: F401
|
|||
|
|
from app.features.conversation import models as _conv_models # noqa: F401
|
|||
|
|
from app.features.memory import models as _memory_models # noqa: F401
|
|||
|
|
from app.features.memoir import models as _memoir_models # noqa: F401
|
|||
|
|
from app.features.payment import models as _payment_models # noqa: F401
|
|||
|
|
from app.features.story import models as _story_models # noqa: F401
|
|||
|
|
from app.features.user import models as _user_models # noqa: F401
|
|||
|
|
|
|||
|
|
from app.core.db import SessionLocal
|
|||
|
|
from app.features.conversation.lineage_schemas import DialogueLineage
|
|||
|
|
from app.features.conversation.models import Conversation, ConversationMessage, Segment
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _first_message_for_segment(
    session: Session, *, segment_id: str, role: str
) -> ConversationMessage | None:
    """Return the earliest message with the given role in a segment.

    "Earliest" means the lowest ``created_at``; ``id`` is used as a
    tie-breaker so the pick is stable when timestamps collide.

    Args:
        session: Session used to run the lookup.
        segment_id: Value matched against ``ConversationMessage.segment_id``.
        role: Value matched against ``ConversationMessage.role``.

    Returns:
        The first matching ``ConversationMessage``, or ``None`` when the
        segment has no message with that role.
    """
    in_segment = ConversationMessage.segment_id == segment_id
    has_role = ConversationMessage.role == role
    oldest_first = (
        ConversationMessage.created_at.asc(),
        ConversationMessage.id.asc(),
    )

    query = select(ConversationMessage).where(in_segment, has_role)
    # LIMIT 1 guarantees scalar_one_or_none() never sees more than one row.
    query = query.order_by(*oldest_first).limit(1)
    result = session.execute(query)
    return result.scalar_one_or_none()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main() -> None:
    """Backfill ``user_message_id`` and ``lineage_json`` on segments lacking lineage.

    Selects every ``Segment`` whose ``lineage_json`` is NULL and whose parent
    ``Conversation`` has a NULL ``deleted_at``, ordered by ``created_at`` and
    then ``id``. For each segment the earliest "human" message is looked up;
    segments without one are skipped and left untouched. Otherwise the
    earliest "ai" message (if any) is paired with it, a single-turn
    ``DialogueLineage`` is built, and both columns are set on the segment.
    All changes are committed with one ``commit()`` after the loop, then a
    summary line is printed.

    Command-line options:
        --limit N  Process at most N segments. ``0`` (the default) or any
                   non-positive value means no limit.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--limit", type=int, default=0, help="最多处理 segment 条数,0 不限制")
    args = parser.parse_args()

    session = SessionLocal()
    updated = 0
    try:
        stmt = (
            select(Segment)
            .join(Conversation, Segment.conversation_id == Conversation.id)
            .where(
                Segment.lineage_json.is_(None),
                Conversation.deleted_at.is_(None),
            )
            # Tie-break on id (as the message lookup does) so that rows sharing
            # a created_at come back in a stable order; without it the batch
            # picked by --limit could differ from run to run.
            .order_by(Segment.created_at.asc(), Segment.id.asc())
        )
        if args.limit > 0:
            stmt = stmt.limit(args.limit)
        segments = session.execute(stmt).scalars().all()

        for seg in segments:
            segment_id = str(seg.id)
            hum = _first_message_for_segment(session, segment_id=segment_id, role="human")
            if not hum:
                # NOTE(review): skipped segments keep lineage_json NULL, so they
                # are selected again on every run and can occupy the --limit
                # batch indefinitely. Confirm whether that is acceptable.
                continue
            ai = _first_message_for_segment(session, segment_id=segment_id, role="ai")
            user_message_id = str(hum.id)
            lineage = DialogueLineage.for_single_turn(
                conversation_id=str(seg.conversation_id),
                user_message_id=user_message_id,
                # The assistant reply is optional; a segment may only have the human turn.
                assistant_message_id=str(ai.id) if ai else None,
                segment_ids=[segment_id],
            )
            seg.user_message_id = user_message_id
            seg.lineage_json = lineage.model_dump(mode="json")
            updated += 1

        session.commit()
        print(f"updated_segments={updated} scanned={len(segments)}")
    finally:
        # Always release the session, including when a query or the commit fails.
        session.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|