Skip to content
Snippets Groups Projects
Commit 07393e5e authored by Simon Künzel
Browse files

filter graph: allow selecting input stream only from video or audio streams

parent 8b59934b
No related branches found
No related tags found
No related merge requests found
Pipeline #7805 passed
Pipeline: backend

#7806

    ...@@ -16,8 +16,27 @@ if TYPE_CHECKING: # pragma: no cover ...@@ -16,8 +16,27 @@ if TYPE_CHECKING: # pragma: no cover
    from videoag_common.objects import Lecture, MediumMetadata, MediumFile from videoag_common.objects import Lecture, MediumMetadata, MediumFile
    class StreamIdContext(JsonSerializableEnum):
        """Scope within which a stream specifier's ``stream_id`` is interpreted.

        ``is_stream_in_context`` decides which probed streams belong to each
        member.
        """

        # Every stream belongs to this context.
        ALL = "all"
        # Only streams that are FFProbeVideoStream instances.
        VIDEO = "video"
        # Only streams that are FFProbeAudioStream instances.
        AUDIO = "audio"
    def is_stream_in_context(context: StreamIdContext, stream: FFProbeStream):
        """Tell whether ``stream`` belongs to the given stream id context.

        :param context: The context to test membership for.
        :param stream: A probed stream.
        :return: ``True`` for ``StreamIdContext.ALL``; for ``VIDEO`` / ``AUDIO``
            whether the stream is an ``FFProbeVideoStream`` /
            ``FFProbeAudioStream`` instance respectively.
        :raises AssertionError: If ``context`` is none of the known members.
        """
        # Guard clauses, checked in the same order as the enum members are declared.
        if context == StreamIdContext.ALL:
            return True
        if context == StreamIdContext.VIDEO:
            return isinstance(stream, FFProbeVideoStream)
        if context == StreamIdContext.AUDIO:
            return isinstance(stream, FFProbeAudioStream)
        # Reaching this point means a member was added without updating this function.
        raise AssertionError(f"Unknown id context: {context}")
    class StreamSpecifier(JsonDataClass): class StreamSpecifier(JsonDataClass):
    stream_id: int stream_id: int
    stream_id_context: StreamIdContext = StreamIdContext.ALL
    graph_id: str graph_id: str
    ...@@ -27,14 +46,10 @@ class InputFileJNode(JGraphNode): ...@@ -27,14 +46,10 @@ class InputFileJNode(JGraphNode):
    def __post_init__(self): def __post_init__(self):
    self._output_ids: set[str] = set() self._output_ids: set[str] = set()
    stream_ids: set[int] = set()
    for specifier in self.streams: for specifier in self.streams:
    if specifier.graph_id in self._output_ids: if specifier.graph_id in self._output_ids:
    raise JsonSerializableInitException(f"Duplicate graph id {specifier.graph_id}") raise JsonSerializableInitException(f"Duplicate graph id {specifier.graph_id}")
    if specifier.stream_id in stream_ids:
    raise JsonSerializableInitException(f"Duplicate stream id {specifier.stream_id}")
    self._output_ids.add(specifier.graph_id) self._output_ids.add(specifier.graph_id)
    stream_ids.add(specifier.stream_id)
    @classmethod @classmethod
    def get_type(cls) -> str: def get_type(cls) -> str:
    ...@@ -80,6 +95,16 @@ class InputFileJNode(JGraphNode): ...@@ -80,6 +95,16 @@ class InputFileJNode(JGraphNode):
    probe.load_data(file) probe.load_data(file)
    probe_streams = probe.all_streams probe_streams = probe.all_streams
    context_stream_mappings: dict[StreamIdContext, list[int]] = {}
    for stream_context in StreamIdContext:
    current_context_mapping = []
    context_stream_mappings[stream_context] = current_context_mapping
    for i, stream in zip(range(len(probe_streams)), probe_streams):
    if is_stream_in_context(stream_context, stream):
    current_context_mapping.append(i)
    continue
    # Some nodes in the pipeline handle video and audio separately, and they need to use "setpts=PTS-STARTPTS" # Some nodes in the pipeline handle video and audio separately, and they need to use "setpts=PTS-STARTPTS"
    # to reset a stream's start pts to 0 (e.g. for the overlay filter). However, if such nodes have to deal with # to reset a stream's start pts to 0 (e.g. for the overlay filter). However, if such nodes have to deal with
    # audio and video (like the slide node), the video and audio need to be kept in sync. Actions like resetting # audio and video (like the slide node), the video and audio need to be kept in sync. Actions like resetting
    ...@@ -98,6 +123,9 @@ class InputFileJNode(JGraphNode): ...@@ -98,6 +123,9 @@ class InputFileJNode(JGraphNode):
    @dataclass @dataclass
    class StreamData: class StreamData:
    stream_id: int
    # list as a stream might be selected multiple times in different contexts
    graph_ids: list[str]
    probe_stream: FFProbeStream probe_stream: FFProbeStream
    type: FStreamType type: FStreamType
    time_base: Fraction time_base: Fraction
    ...@@ -125,11 +153,20 @@ class InputFileJNode(JGraphNode): ...@@ -125,11 +153,20 @@ class InputFileJNode(JGraphNode):
    stream_data_by_id: dict[int, StreamData] = {} stream_data_by_id: dict[int, StreamData] = {}
    for specifier in self.streams: for specifier in self.streams:
    if specifier.stream_id >= len(probe_streams): context_stream_ids = context_stream_mappings[specifier.stream_id_context]
    raise ValueError(f"Invalid Stream ID {specifier.stream_id} for file {file} which has only" if specifier.stream_id >= len(context_stream_ids):
    f" {len(probe_streams)} streams") raise ValueError(f"Invalid Stream {specifier.stream_id} in context {specifier.stream_id_context} for"
    f" file {file}. That context only has {len(context_stream_ids)} streams")
    stream_id = context_stream_ids[specifier.stream_id]
    probe_stream = probe_streams[stream_id]
    logger.info(f"Stream {specifier.stream_id} in context {specifier.stream_id_context} maps to global stream ID {stream_id}")
    probe_stream = probe_streams[specifier.stream_id] if stream_id in stream_data_by_id:
    stream_data = stream_data_by_id[stream_id]
    logger.warning(f"Stream {stream_id} was mapped again to graph id '{specifier.graph_id}' (already mapped to '{"', '".join(stream_data.graph_ids)}')")
    stream_data.graph_ids.append(specifier.graph_id)
    continue
    if isinstance(probe_stream, FFProbeVideoStream): if isinstance(probe_stream, FFProbeVideoStream):
    stream_type = FStreamType.VIDEO stream_type = FStreamType.VIDEO
    ...@@ -137,22 +174,25 @@ class InputFileJNode(JGraphNode): ...@@ -137,22 +174,25 @@ class InputFileJNode(JGraphNode):
    stream_type = FStreamType.AUDIO stream_type = FStreamType.AUDIO
    else: else:
    raise ValueError(f"Unknown stream type {type(probe_stream)} returned by ffprobe for file {file}" raise ValueError(f"Unknown stream type {type(probe_stream)} returned by ffprobe for file {file}"
    f" for stream {specifier.stream_id}") f" for stream {stream_id}")
    time_base = parse_fraction(probe_stream.raw_data["time_base"]) time_base = parse_fraction(probe_stream.raw_data["time_base"])
    if "duration_ts" in probe_stream.raw_data: if "duration_ts" in probe_stream.raw_data:
    duration_ts = int(probe_stream.raw_data["duration_ts"]) duration_ts = int(probe_stream.raw_data["duration_ts"])
    else: else:
    duration_ts = Fraction(probe_stream.duration_ns, int(math.pow(10, 9))) / time_base duration_ts = Fraction(probe_stream.duration_ns, int(math.pow(10, 9))) / time_base
    stream_data = StreamData( stream_data = StreamData(
    stream_id=stream_id,
    graph_ids=[specifier.graph_id],
    probe_stream=probe_stream, probe_stream=probe_stream,
    type=stream_type, type=stream_type,
    time_base=time_base, time_base=time_base,
    start_pts=int(probe_stream.raw_data["start_pts"]), start_pts=int(probe_stream.raw_data["start_pts"]),
    duration_ts=duration_ts duration_ts=duration_ts
    ) )
    stream_data_by_id[specifier.stream_id] = stream_data stream_data_by_id[stream_id] = stream_data
    logger.info(f"Input file {file}, stream {specifier.stream_id} (in graph as {specifier.graph_id}): {stream_data}") logger.info(f"Input file {file}, stream {stream_id} (in graph as {specifier.graph_id}): {stream_data}")
    if latest_start_time is None: if latest_start_time is None:
    latest_start_time = stream_data.start_time_sec latest_start_time = stream_data.start_time_sec
    ...@@ -169,8 +209,7 @@ class InputFileJNode(JGraphNode): ...@@ -169,8 +209,7 @@ class InputFileJNode(JGraphNode):
    stream_fids: dict[int, FID] = {} stream_fids: dict[int, FID] = {}
    for specifier in self.streams: for stream_id, stream_data in stream_data_by_id.items():
    stream_data = stream_data_by_id[specifier.stream_id]
    final_start_pts = round(latest_start_time / stream_data.time_base) final_start_pts = round(latest_start_time / stream_data.time_base)
    final_end_pts = round(earliest_end_time / stream_data.time_base) final_end_pts = round(earliest_end_time / stream_data.time_base)
    ...@@ -180,10 +219,11 @@ class InputFileJNode(JGraphNode): ...@@ -180,10 +219,11 @@ class InputFileJNode(JGraphNode):
    or final_duration_ts != stream_data.duration_ts) or final_duration_ts != stream_data.duration_ts)
    needs_start_pts_reset = needs_trim or final_start_pts != 0 needs_start_pts_reset = needs_trim or final_start_pts != 0
    metadata_fid = FID(stream_data.type, name_hint=f"jid_{specifier.graph_id}") jid_name_hint = f"jid_{"_".join(stream_data.graph_ids)}"
    metadata_fid = FID(stream_data.type, name_hint=jid_name_hint)
    if needs_start_pts_reset: if needs_start_pts_reset:
    input_fid = FID(stream_data.type, name_hint=f"jid_{specifier.graph_id}_pre_aligned") input_fid = FID(stream_data.type, name_hint=f"{jid_name_hint}_pre_aligned")
    else: else:
    input_fid = metadata_fid input_fid = metadata_fid
    ...@@ -193,7 +233,7 @@ class InputFileJNode(JGraphNode): ...@@ -193,7 +233,7 @@ class InputFileJNode(JGraphNode):
    chain = [] chain = []
    if needs_trim: if needs_trim:
    logger.info(f"Trimming stream {specifier.stream_id} to" logger.info(f"Trimming stream {stream_id} to"
    f" start_pts={final_start_pts} ({float(final_start_pts * stream_data.time_base):.8f}s)," f" start_pts={final_start_pts} ({float(final_start_pts * stream_data.time_base):.8f}s),"
    f" end_pts={final_end_pts} ({float(final_end_pts * stream_data.time_base):.8f}s)" f" end_pts={final_end_pts} ({float(final_end_pts * stream_data.time_base):.8f}s)"
    f" (duration {final_duration_ts} or {float(final_duration_ts * stream_data.time_base):.8f}s)") f" (duration {final_duration_ts} or {float(final_duration_ts * stream_data.time_base):.8f}s)")
    ...@@ -202,7 +242,7 @@ class InputFileJNode(JGraphNode): ...@@ -202,7 +242,7 @@ class InputFileJNode(JGraphNode):
    "end_pts": final_end_pts "end_pts": final_end_pts
    })) }))
    logger.info(f"Resetting STARTPTS for stream {specifier.stream_id} to 0") logger.info(f"Resetting STARTPTS for stream {stream_id} to 0")
    chain.append(ChainFFilter(f"{filter_prefix}setpts", { chain.append(ChainFFilter(f"{filter_prefix}setpts", {
    "expr": "PTS-STARTPTS" "expr": "PTS-STARTPTS"
    })) }))
    ...@@ -226,8 +266,9 @@ class InputFileJNode(JGraphNode): ...@@ -226,8 +266,9 @@ class InputFileJNode(JGraphNode):
    else: else:
    raise AssertionError("Unknown type") raise AssertionError("Unknown type")
    stream_fids[specifier.stream_id] = input_fid stream_fids[stream_id] = input_fid
    context.add_stream_metadata(specifier.graph_id, metadata) for graph_id in stream_data.graph_ids:
    context.add_stream_metadata(graph_id, metadata)
    context.add_input_file(file, stream_fids) context.add_input_file(file, stream_fids)
    ......
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Please register or sign in to comment