"""CRUD operations for transcripts."""
import asyncio
from typing import List

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from youtube_transcript_api import FetchedTranscript

from lingua_loop.db.models import Segment
from lingua_loop.db.models import Transcript
from lingua_loop.exceptions import TranscriptNotFoundError
from lingua_loop.integrations.youtube.types import SupportedLanguageCodes
from lingua_loop.integrations.youtube.wrapper import fetch_transcript
from lingua_loop.integrations.youtube.wrapper import list_transcripts
from lingua_loop.integrations.youtube.wrapper import (
    video_has_transcript_in_language,
)
[docs]
async def read_or_create_transcript_with_segments(
    video_id: str, language_code: SupportedLanguageCodes, session: AsyncSession
) -> Transcript:
    """Return the transcript (with segments) for a video, creating it on a miss.

    A locally stored transcript for ``video_id`` is returned as-is. Otherwise
    a new one is fetched from YouTube in ``language_code``, persisted, and
    returned.

    Args:
        video_id: YouTube video ID.
        language_code: Language to fetch when nothing is stored yet. It is not
            used to filter the local lookup.
        session: Async database session used for both the lookup and the insert.

    Returns:
        The stored or newly created ``Transcript``.
    """
    cached = await _read_transcript_with_segments(video_id=video_id, session=session)
    if cached:
        return cached

    # Cache miss: pull from YouTube and persist.
    return await _create_transcript(
        video_id=video_id, language_code=language_code, session=session
    )
async def _read_transcript_with_segments(
    video_id: str, session: AsyncSession
) -> Transcript | None:
    """Look up the stored transcript for ``video_id`` with its segments preloaded.

    Segments are loaded eagerly via ``selectinload``, so the caller can read
    ``transcript.segments`` without triggering a lazy load.

    NOTE: The lookup filters on ``video_id`` only, not on language code. If a
    client requests the wrong language but a transcript for the video is
    already stored locally, the client still receives that stored transcript.

    Args:
        video_id: YouTube video ID to look up.
        session: Async database session.

    Returns:
        The matching ``Transcript``, or ``None`` when no row exists.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: If more than one transcript row
            matches ``video_id``. This is raised by ``scalar_one_or_none``.
    """
    stmt = (
        select(Transcript)
        .where(Transcript.video_id == video_id)
        .options(selectinload(Transcript.segments))
    )
    return (await session.execute(stmt)).scalar_one_or_none()
async def _create_transcript(
    video_id: str, language_code: SupportedLanguageCodes, session: AsyncSession
) -> Transcript:
    """Fetch a transcript from YouTube and persist it together with its segments.

    ``list_transcripts`` and ``fetch_transcript`` are synchronous callables.
    They are presumably backed by the synchronous ``youtube_transcript_api``
    HTTP client. Calling them directly inside this coroutine would block the
    event loop for the whole network round-trip, so each one runs in a worker
    thread via ``asyncio.to_thread``.

    Args:
        video_id: YouTube video ID to fetch.
        language_code: Language of the transcript to fetch and store.
        session: Async database session. The new transcript is added and
            committed through it.

    Returns:
        The newly created ``Transcript`` with its ``segments`` attached.

    Raises:
        TranscriptNotFoundError: If the video has no transcript in
            ``language_code``.

    NOTE(review): If the session factory uses ``expire_on_commit=True``, the
    returned instance is expired after ``commit()``. Reading ``segments`` then
    triggers a lazy load, which is not allowed under asyncio. Confirm that
    ``expire_on_commit=False`` is configured for this session.
    """
    # Offload the blocking network calls so other requests keep being served.
    transcript_list = await asyncio.to_thread(list_transcripts, video_id=video_id)

    if not video_has_transcript_in_language(
        transcript_list=transcript_list, language_code=language_code
    ):
        raise TranscriptNotFoundError(video_id=video_id)

    fetched_transcript = await asyncio.to_thread(
        fetch_transcript, video_id=video_id, language_code=language_code
    )

    transcript = Transcript(
        video_id=video_id,
        language_code=language_code,
        is_generated=fetched_transcript.is_generated,
    )
    transcript.segments.extend(_get_segments(fetched_transcript=fetched_transcript))

    session.add(transcript)
    await session.commit()
    return transcript
def _get_segments(fetched_transcript: FetchedTranscript) -> List[Segment]:
    """Convert a fetched YouTube transcript into new, unsaved ``Segment`` objects.

    Args:
        fetched_transcript: Transcript returned by the YouTube wrapper. Only its
            ``snippets`` are read, and from each snippet only ``start``,
            ``duration`` and ``text``.

    Returns:
        One ``Segment`` per snippet, in snippet order. The list is empty when
        the transcript has no snippets.
    """
    return [
        Segment(start=snippet.start, duration=snippet.duration, text=snippet.text)
        for snippet in fetched_transcript.snippets
    ]