diff --git a/common_py/src/videoag_common/media_process/jnode/file_jnode.py b/common_py/src/videoag_common/media_process/jnode/file_jnode.py index 4771a2e918681329aa80cfe3a73246f0536d51ce..7299cb1b5023221f121bae81827914278d3e1168 100644 --- a/common_py/src/videoag_common/media_process/jnode/file_jnode.py +++ b/common_py/src/videoag_common/media_process/jnode/file_jnode.py @@ -16,8 +16,27 @@ if TYPE_CHECKING: # pragma: no cover from videoag_common.objects import Lecture, MediumMetadata, MediumFile +class StreamIdContext(JsonSerializableEnum): + ALL = "all" + VIDEO = "video" + AUDIO = "audio" + + +def is_stream_in_context(context: StreamIdContext, stream: FFProbeStream): + match context: + case StreamIdContext.ALL: + return True + case StreamIdContext.VIDEO: + return isinstance(stream, FFProbeVideoStream) + case StreamIdContext.AUDIO: + return isinstance(stream, FFProbeAudioStream) + case _: + raise AssertionError(f"Unknown id context: {context}") + + class StreamSpecifier(JsonDataClass): stream_id: int + stream_id_context: StreamIdContext = StreamIdContext.ALL graph_id: str @@ -27,14 +46,10 @@ class InputFileJNode(JGraphNode): def __post_init__(self): self._output_ids: set[str] = set() - stream_ids: set[int] = set() for specifier in self.streams: if specifier.graph_id in self._output_ids: raise JsonSerializableInitException(f"Duplicate graph id {specifier.graph_id}") - if specifier.stream_id in stream_ids: - raise JsonSerializableInitException(f"Duplicate stream id {specifier.stream_id}") self._output_ids.add(specifier.graph_id) - stream_ids.add(specifier.stream_id) @classmethod def get_type(cls) -> str: @@ -80,6 +95,16 @@ class InputFileJNode(JGraphNode): probe.load_data(file) probe_streams = probe.all_streams + context_stream_mappings: dict[StreamIdContext, list[int]] = {} + for stream_context in StreamIdContext: + current_context_mapping = [] + context_stream_mappings[stream_context] = current_context_mapping + + for i, stream in zip(range(len(probe_streams)), probe_streams): + if is_stream_in_context(stream_context, stream): + current_context_mapping.append(i) + continue + # Some nodes in the pipeline handle video and audio separately, and they need to use "setpts=PTS-STARTPTS" # to reset a streams start pts to 0 (e.g. for the overlay filter). However, if such nodes have to deal with # audio and video (like the slide node), the video and audio needs to be kept in sync. Actions like resetting @@ -98,6 +123,9 @@ class InputFileJNode(JGraphNode): @dataclass class StreamData: + stream_id: int + # list as a stream might be selected multiple times in different contexts + graph_ids: list[str] probe_stream: FFProbeStream type: FStreamType time_base: Fraction @@ -125,11 +153,20 @@ class InputFileJNode(JGraphNode): stream_data_by_id: dict[int, StreamData] = {} for specifier in self.streams: - if specifier.stream_id >= len(probe_streams): - raise ValueError(f"Invalid Stream ID {specifier.stream_id} for file {file} which has only" - f" {len(probe_streams)} streams") + context_stream_ids = context_stream_mappings[specifier.stream_id_context] + if specifier.stream_id >= len(context_stream_ids): + raise ValueError(f"Invalid Stream {specifier.stream_id} in context {specifier.stream_id_context} for" + f" file {file}. That context only has {len(context_stream_ids)} streams") + stream_id = context_stream_ids[specifier.stream_id] + probe_stream = probe_streams[stream_id] + + logger.info(f"Stream {specifier.stream_id} in context {specifier.stream_id_context} maps to global stream ID {stream_id}") - probe_stream = probe_streams[specifier.stream_id] + if stream_id in stream_data_by_id: + stream_data = stream_data_by_id[stream_id] + logger.warning(f"Stream {stream_id} was mapped again to graph id '{specifier.graph_id}' (already mapped to '{"', '".join(stream_data.graph_ids)}')") + stream_data.graph_ids.append(specifier.graph_id) + continue if isinstance(probe_stream, FFProbeVideoStream): stream_type = FStreamType.VIDEO @@ -137,22 +174,25 @@ class InputFileJNode(JGraphNode): stream_type = FStreamType.AUDIO else: raise ValueError(f"Unknown stream type {type(probe_stream)} returned by ffprobe for file {file}" - f" for stream {specifier.stream_id}") + f" for stream {stream_id}") time_base = parse_fraction(probe_stream.raw_data["time_base"]) if "duration_ts" in probe_stream.raw_data: duration_ts = int(probe_stream.raw_data["duration_ts"]) else: duration_ts = Fraction(probe_stream.duration_ns, int(math.pow(10, 9))) / time_base + stream_data = StreamData( + stream_id=stream_id, + graph_ids=[specifier.graph_id], probe_stream=probe_stream, type=stream_type, time_base=time_base, start_pts=int(probe_stream.raw_data["start_pts"]), duration_ts=duration_ts ) - stream_data_by_id[specifier.stream_id] = stream_data - logger.info(f"Input file {file}, stream {specifier.stream_id} (in graph as {specifier.graph_id}): {stream_data}") + stream_data_by_id[stream_id] = stream_data + logger.info(f"Input file {file}, stream {stream_id} (in graph as {specifier.graph_id}): {stream_data}") if latest_start_time is None: latest_start_time = stream_data.start_time_sec @@ -169,8 +209,7 @@ class InputFileJNode(JGraphNode): stream_fids: dict[int, FID] = {} - for specifier in self.streams: - stream_data = stream_data_by_id[specifier.stream_id] + for stream_id, stream_data in stream_data_by_id.items(): final_start_pts = round(latest_start_time / stream_data.time_base) final_end_pts = round(earliest_end_time / stream_data.time_base) @@ -180,10 +219,11 @@ class InputFileJNode(JGraphNode): or final_duration_ts != stream_data.duration_ts) needs_start_pts_reset = needs_trim or final_start_pts != 0 - metadata_fid = FID(stream_data.type, name_hint=f"jid_{specifier.graph_id}") + jid_name_hint = f"jid_{"_".join(stream_data.graph_ids)}" + metadata_fid = FID(stream_data.type, name_hint=jid_name_hint) if needs_start_pts_reset: - input_fid = FID(stream_data.type, name_hint=f"jid_{specifier.graph_id}_pre_aligned") + input_fid = FID(stream_data.type, name_hint=f"{jid_name_hint}_pre_aligned") else: input_fid = metadata_fid @@ -193,7 +233,7 @@ class InputFileJNode(JGraphNode): chain = [] if needs_trim: - logger.info(f"Trimming stream {specifier.stream_id} to" + logger.info(f"Trimming stream {stream_id} to" f" start_pts={final_start_pts} ({float(final_start_pts * stream_data.time_base):.8f}s)," f" end_pts={final_end_pts} ({float(final_end_pts * stream_data.time_base):.8f}s)" f" (duration {final_duration_ts} or {float(final_duration_ts * stream_data.time_base):.8f}s)") @@ -202,7 +242,7 @@ class InputFileJNode(JGraphNode): "end_pts": final_end_pts })) - logger.info(f"Resetting STARTPTS for stream {specifier.stream_id} to 0") + logger.info(f"Resetting STARTPTS for stream {stream_id} to 0") chain.append(ChainFFilter(f"{filter_prefix}setpts", { "expr": "PTS-STARTPTS" })) @@ -226,8 +266,9 @@ class InputFileJNode(JGraphNode): else: raise AssertionError("Unknown type") - stream_fids[specifier.stream_id] = input_fid - context.add_stream_metadata(specifier.graph_id, metadata) + stream_fids[stream_id] = input_fid + for graph_id in stream_data.graph_ids: + context.add_stream_metadata(graph_id, metadata) context.add_input_file(file, stream_fids)