91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
#!/usr/bin/env python3
|
||
"""回填 segments.user_message_id / lineage_json(由既有 conversation_messages 配对)。
|
||
|
||
用法::
|
||
|
||
cd api && uv run python scripts/backfill_segment_dialogue_lineage.py
|
||
cd api && uv run python scripts/backfill_segment_dialogue_lineage.py --limit 500
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.features.auth import models as _auth_models # noqa: F401
|
||
from app.features.conversation import models as _conv_models # noqa: F401
|
||
from app.features.memory import models as _memory_models # noqa: F401
|
||
from app.features.memoir import models as _memoir_models # noqa: F401
|
||
from app.features.payment import models as _payment_models # noqa: F401
|
||
from app.features.story import models as _story_models # noqa: F401
|
||
from app.features.user import models as _user_models # noqa: F401
|
||
|
||
from app.core.db import SessionLocal
|
||
from app.features.conversation.lineage_schemas import DialogueLineage
|
||
from app.features.conversation.models import Conversation, ConversationMessage, Segment
|
||
|
||
|
||
def _first_message_for_segment(
    session: Session, *, segment_id: str, role: str
) -> ConversationMessage | None:
    """Return the earliest message with ``role`` attached to ``segment_id``.

    Messages are ranked by ``created_at`` ascending, then by ``id`` ascending,
    so equal timestamps still resolve to one stable row. Returns ``None`` when
    the segment has no message with the requested role.
    """
    filters = (
        ConversationMessage.segment_id == segment_id,
        ConversationMessage.role == role,
    )
    ordering = (ConversationMessage.created_at.asc(), ConversationMessage.id.asc())
    query = select(ConversationMessage).where(*filters).order_by(*ordering).limit(1)
    # limit(1) guarantees at most one row, so first() matches scalar_one_or_none().
    return session.execute(query).scalars().first()
|
||
|
||
|
||
def main() -> None:
    """Backfill ``user_message_id`` / ``lineage_json`` on segments lacking lineage.

    Selects segments whose ``lineage_json`` is NULL, whose conversation has
    ``deleted_at`` NULL, and which have at least one ``human`` message. For each
    one:

    * the earliest ``human`` message becomes ``Segment.user_message_id``;
    * the earliest ``ai`` message, if any, becomes the assistant side;
    * ``lineage_json`` is set from ``DialogueLineage.for_single_turn``.

    All updates are committed in a single ``session.commit()``. The session is
    closed in every case.

    ``--limit N`` (N > 0) caps how many segments are processed in one run.
    ``0`` (the default) means no cap, and negative values are rejected.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--limit", type=int, default=0, help="最多处理 segment 条数,0 不限制"
    )
    args = p.parse_args()
    if args.limit < 0:
        p.error("--limit must be >= 0")

    # Only segments that actually have a human message can be backfilled.
    # Filtering them in SQL (instead of skipping them inside the loop) matters
    # for --limit: skipped segments keep lineage_json NULL. Without this filter,
    # repeated batched runs would keep re-selecting the same unmatchable oldest
    # rows and never advance. NOTE(review): assumes segment_id and Segment.id
    # have compatible column types (expected if segment_id is an FK).
    has_human_message = (
        select(ConversationMessage.id)
        .where(
            ConversationMessage.segment_id == Segment.id,
            ConversationMessage.role == "human",
        )
        .exists()
    )

    session = SessionLocal()
    n = 0
    try:
        stmt = (
            select(Segment)
            .join(Conversation, Segment.conversation_id == Conversation.id)
            .where(
                Segment.lineage_json.is_(None),
                Conversation.deleted_at.is_(None),
                has_human_message,
            )
            # id breaks created_at ties so each --limit batch is deterministic.
            .order_by(Segment.created_at.asc(), Segment.id.asc())
        )
        if args.limit > 0:
            stmt = stmt.limit(args.limit)
        segments = session.execute(stmt).scalars().all()

        for seg in segments:
            seg_id = str(seg.id)
            hum = _first_message_for_segment(session, segment_id=seg_id, role="human")
            if hum is None:
                # Defensive: the EXISTS filter should exclude these rows, but the
                # message set can change between the two queries.
                continue
            ai = _first_message_for_segment(session, segment_id=seg_id, role="ai")
            lineage = DialogueLineage.for_single_turn(
                conversation_id=str(seg.conversation_id),
                user_message_id=str(hum.id),
                assistant_message_id=str(ai.id) if ai is not None else None,
                segment_ids=[seg_id],
            )
            seg.user_message_id = str(hum.id)
            seg.lineage_json = lineage.model_dump(mode="json")
            n += 1

        session.commit()
        print(f"updated_segments={n} scanned={len(segments)}")
    finally:
        session.close()
|
||
|
||
|
||
if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0, as before.
    raise SystemExit(main())
|