"""memory_sources: segment_id for phase1 ingest idempotency

Revision ID: 0021_memory_source_segment_id
Revises: 0020_refresh_rt_lineage
"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

revision: str = "0021_memory_source_segment_id"
down_revision: Union[str, None] = "0020_refresh_rt_lineage"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# Schema object names shared by upgrade() and downgrade(). Keeping them in one
# place guarantees the downgrade removes exactly what the upgrade created.
_TABLE = "memory_sources"
_COLUMN = "segment_id"
_INDEX = "ix_memory_sources_user_segment_transcript"
_FOREIGN_KEY = "fk_memory_sources_segment_id_segments"


def _column_names(table_name: str) -> set[str]:
    """Return the names of the columns currently present on *table_name*."""
    # A fresh inspector per call: an Inspector caches reflection results, so
    # reusing one across DDL statements could report stale schema state.
    inspector = sa.inspect(op.get_bind())
    return {column["name"] for column in inspector.get_columns(table_name)}


def _index_names(table_name: str) -> set[str]:
    """Return the names of the indexes currently present on *table_name*."""
    inspector = sa.inspect(op.get_bind())
    return {index["name"] for index in inspector.get_indexes(table_name)}


def _foreign_key_names(table_name: str) -> set[str]:
    """Return the names of the named foreign keys currently on *table_name*.

    Foreign keys reflected without a name (``name`` is ``None``) are skipped,
    so the result contains only strings.
    """
    inspector = sa.inspect(op.get_bind())
    return {
        fk["name"]
        for fk in inspector.get_foreign_keys(table_name)
        if fk["name"] is not None
    }


def upgrade() -> None:
    """Add ``memory_sources.segment_id`` with its unique index and foreign key.

    Each step inspects the live schema first and only creates the object if it
    is missing. Re-running the upgrade against a database that already has
    some of these objects therefore leaves those objects untouched.
    """
    if _COLUMN not in _column_names(_TABLE):
        op.add_column(
            _TABLE,
            sa.Column(_COLUMN, sa.String(), nullable=True),
        )

    if _INDEX not in _index_names(_TABLE):
        # At most one transcript memory source per (user_id, segment_id).
        # The uniqueness is restricted to non-NULL segment_id rows with
        # source_type 'transcript' via a PostgreSQL partial index.
        # NOTE(review): postgresql_where is PostgreSQL-specific; on other
        # dialects this would become a plain unique index over
        # (user_id, segment_id). Confirm only PostgreSQL is targeted.
        op.create_index(
            _INDEX,
            _TABLE,
            ["user_id", "segment_id"],
            unique=True,
            postgresql_where=sa.text(
                "segment_id IS NOT NULL AND source_type = 'transcript'"
            ),
        )

    if _FOREIGN_KEY not in _foreign_key_names(_TABLE):
        # ON DELETE SET NULL: deleting a segment clears segment_id on the
        # referencing memory sources instead of deleting those rows.
        # NOTE(review): assumes segments.id is a string type compatible with
        # the sa.String() column added above. Confirm against the segments
        # table.
        op.create_foreign_key(
            _FOREIGN_KEY,
            _TABLE,
            "segments",
            [_COLUMN],
            ["id"],
            ondelete="SET NULL",
        )


def downgrade() -> None:
    """Remove the foreign key, index and column added by ``upgrade``.

    Objects are dropped in the order FK, then index, then the column they both
    reference. Each object is dropped only if it is still present.
    """
    if _FOREIGN_KEY in _foreign_key_names(_TABLE):
        op.drop_constraint(
            _FOREIGN_KEY,
            _TABLE,
            type_="foreignkey",
        )

    if _INDEX in _index_names(_TABLE):
        op.drop_index(
            _INDEX,
            table_name=_TABLE,
        )

    if _COLUMN in _column_names(_TABLE):
        op.drop_column(_TABLE, _COLUMN)