Source code for lingua_loop.db.transcript

"""CRUD operations for transcripts."""

import asyncio
from typing import List

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from youtube_transcript_api import FetchedTranscript

from lingua_loop.db.models import Segment
from lingua_loop.db.models import Transcript
from lingua_loop.exceptions import TranscriptNotFoundError
from lingua_loop.integrations.youtube.types import SupportedLanguageCodes
from lingua_loop.integrations.youtube.wrapper import fetch_transcript
from lingua_loop.integrations.youtube.wrapper import list_transcripts
from lingua_loop.integrations.youtube.wrapper import (
    video_has_transcript_in_language,
)


async def read_or_create_transcript_with_segments(
    video_id: str,
    language_code: SupportedLanguageCodes,
    session: AsyncSession,
) -> Transcript:
    """Return the transcript for a video, creating it when none is stored.

    The stored transcript is looked up by ``video_id`` first. Only when
    that lookup yields nothing is a new transcript created for
    ``language_code`` through ``_create_transcript``.
    """
    existing = await _read_transcript_with_segments(
        video_id=video_id, session=session
    )
    if existing:
        return existing

    # Nothing stored for this video yet: build and persist a new record.
    return await _create_transcript(
        video_id=video_id, language_code=language_code, session=session
    )
async def _read_transcript_with_segments(
    video_id: str, session: AsyncSession
) -> Transcript | None:
    """Return the stored transcript for ``video_id`` with segments loaded.

    NOTE: The query filters on ``video_id`` only and does not check the
    language code. Even if a client sends a request with an incorrect
    language code, they still get a valid transcript response when the
    transcript is already cached locally.

    Returns:
        The matching ``Transcript`` (segments eagerly loaded through
        ``selectinload``), or ``None`` when no row matches.
    """
    query = (
        select(Transcript)
        .options(selectinload(Transcript.segments))
        .where(Transcript.video_id == video_id)
    )
    result = await session.execute(query)
    return result.scalar_one_or_none()


async def _create_transcript(
    video_id: str,
    language_code: SupportedLanguageCodes,
    session: AsyncSession,
) -> Transcript:
    """Create a new transcript record with segments from YouTube.

    Args:
        video_id: Identifier of the video whose transcript is fetched.
        language_code: Language the transcript must be available in.
        session: Async session used to add and commit the new record.

    Returns:
        The newly committed ``Transcript`` with its segments attached.

    Raises:
        TranscriptNotFoundError: If the video has no transcript in
            ``language_code``.
    """
    # The YouTube wrappers are plain synchronous callables (they are not
    # awaited) and presumably perform blocking network I/O. Run them in a
    # worker thread so the event loop stays responsive meanwhile.
    transcript_list = await asyncio.to_thread(
        list_transcripts, video_id=video_id
    )
    if not video_has_transcript_in_language(
        transcript_list=transcript_list, language_code=language_code
    ):
        raise TranscriptNotFoundError(video_id=video_id)

    fetched_transcript = await asyncio.to_thread(
        fetch_transcript, video_id=video_id, language_code=language_code
    )

    transcript = Transcript(
        video_id=video_id,
        language_code=language_code,
        is_generated=fetched_transcript.is_generated,
    )
    transcript.segments.extend(
        _get_segments(fetched_transcript=fetched_transcript)
    )

    session.add(transcript)
    await session.commit()
    return transcript


def _get_segments(fetched_transcript: FetchedTranscript) -> List[Segment]:
    """Convert a FetchedTranscript to a list of Segment ORM objects.

    One ``Segment`` is built per snippet, in snippet order, copying the
    snippet's ``start``, ``duration`` and ``text`` values.
    """
    return [
        Segment(
            start=snippet.start,
            duration=snippet.duration,
            text=snippet.text,
        )
        for snippet in fetched_transcript.snippets
    ]