From 87923800e7fb6a75194dd63fb9cf135b6140e1b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de>
Date: Sat, 3 May 2025 16:52:23 +0200
Subject: [PATCH] #45 Fix duration_ts not always available in ffprobe

---
 .../src/videoag_common/ffmpeg/ffprobe.py      | 17 ++++++---
 .../media_process/jnode/file_jnode.py         | 38 +++++++++++--------
 .../videoag_common/objects/medium_metadata.py |  7 ++--
 3 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/common_py/src/videoag_common/ffmpeg/ffprobe.py b/common_py/src/videoag_common/ffmpeg/ffprobe.py
index d2b87dc..aab86f6 100644
--- a/common_py/src/videoag_common/ffmpeg/ffprobe.py
+++ b/common_py/src/videoag_common/ffmpeg/ffprobe.py
@@ -8,6 +8,7 @@ from pathlib import Path
 
 @dataclass
 class FFProbeStream:
+    raw_data: dict
     codec: str
 
 
@@ -17,14 +18,14 @@ class FFProbeVideoStream(FFProbeStream):
     height: int
     frame_rate_numerator: int
     frame_rate_denominator: int
-    duration_sec: int
+    duration_ns: int
 
 
 @dataclass
 class FFProbeAudioStream(FFProbeStream):
     sample_rate: int
     channel_count: int
-    duration_sec: int
+    duration_ns: int
 
 
 @dataclass
@@ -47,10 +48,10 @@ def _parse_duration_string_human(val: str) -> int:
     )
 
 
-def _get_stream_duration_sec(stream: dict):
+def _get_stream_duration_ns(stream: dict):
     duration_string = stream.get("duration")
     if duration_string is not None:
-        return math.ceil(float(stream["duration"]))
+        return math.ceil(float(stream["duration"]) * math.pow(10, 9))
     tags = stream.get("tags")
     if tags is not None:
         duration_string = tags.get("DURATION")
@@ -91,6 +92,7 @@ class FFProbe:
                     or file_name.split(".")[-1] in ["png", "jpg", "jpeg"]):
                 # .jpg is actually recognized as a 'Motion JPEG' with 1 frame
                 probe_stream = FFProbeImageStream(
+                    raw_data=stream,
                     codec=stream["codec_name"],
                     width=int(stream["width"]),
                     height=int(stream["height"]),
@@ -99,24 +101,27 @@ class FFProbe:
             elif stream["codec_type"] == "video":
                 frame_rate_num, frame_rate_den = tuple(stream["avg_frame_rate"].split("/"))
                 probe_stream = FFProbeVideoStream(
+                    raw_data=stream,
                     codec=stream["codec_name"],
                     width=int(stream["width"]),
                     height=int(stream["height"]),
                     frame_rate_numerator=int(frame_rate_num),
                     frame_rate_denominator=int(frame_rate_den),
-                    duration_sec=_get_stream_duration_sec(stream),
+                    duration_ns=_get_stream_duration_ns(stream),
                 )
                 self.video_streams.append(probe_stream)
             elif stream["codec_type"] == "audio":
                 probe_stream = FFProbeAudioStream(
+                    raw_data=stream,
                     codec=stream["codec_name"],
                     sample_rate=int(stream["sample_rate"]),
                     channel_count=int(stream["channels"]),
-                    duration_sec=_get_stream_duration_sec(stream),
+                    duration_ns=_get_stream_duration_ns(stream),
                 )
                 self.audio_streams.append(probe_stream)
             else:
                 probe_stream = FFProbeStream(
+                    raw_data=stream,
                     codec=stream.get("codec_name")
                 )
             self.all_streams.append(probe_stream)
diff --git a/common_py/src/videoag_common/media_process/jnode/file_jnode.py b/common_py/src/videoag_common/media_process/jnode/file_jnode.py
index 3ef28ac..b59f764 100644
--- a/common_py/src/videoag_common/media_process/jnode/file_jnode.py
+++ b/common_py/src/videoag_common/media_process/jnode/file_jnode.py
@@ -1,4 +1,5 @@
 import logging
+import math
 from dataclasses import dataclass
 from fractions import Fraction
 from typing import TYPE_CHECKING
@@ -78,7 +79,7 @@ class InputFileJNode(JGraphNode):
         file = context.data_dir.joinpath(job_data.get_string("file_path", max_length=1000)).resolve()
         probe = FFProbe()
         probe.load_data(file)
-        probe_streams = probe.raw_data["streams"]
+        probe_streams = probe.all_streams
         
         # Some nodes in the pipeline handle video and audio separately, and they need to use "setpts=PTS-STARTPTS"
         # to reset a streams start pts to 0 (e.g. for the overlay filter). However, if such nodes have to deal with
@@ -98,6 +99,7 @@ class InputFileJNode(JGraphNode):
         
         @dataclass
         class StreamData:
+            probe_stream: FFProbeStream
             type: FStreamType
             time_base: Fraction
             start_pts: int
@@ -130,19 +132,25 @@ class InputFileJNode(JGraphNode):
             
             probe_stream = probe_streams[specifier.stream_id]
             
-            type = {
-                "video": FStreamType.VIDEO,
-                "audio": FStreamType.AUDIO
-            }.get(probe_stream["codec_type"])
-            if type is None:
-                raise ValueError(f"Unknown codec type {probe_stream["codec_type"]} returned by ffprobe for file {file}"
+            if isinstance(probe_stream, FFProbeVideoStream):
+                stream_type = FStreamType.VIDEO
+            elif isinstance(probe_stream, FFProbeAudioStream):
+                stream_type = FStreamType.AUDIO
+            else:
+                raise ValueError(f"Unknown stream type {type(probe_stream)} returned by ffprobe for file {file}"
                                  f" for stream {specifier.stream_id}")
             
+            time_base = parse_fraction(probe_stream.raw_data["time_base"])
+            if "duration_ts" in probe_stream.raw_data:
+                duration_ts = int(probe_stream.raw_data["duration_ts"])
+            else:
+                duration_ts = round(Fraction(probe_stream.duration_ns, int(math.pow(10, 9))) / time_base)
             stream_data = StreamData(
-                type=type,
-                time_base=parse_fraction(probe_stream["time_base"]),
-                start_pts=int(probe_stream["start_pts"]),
-                duration_ts=int(probe_stream["duration_ts"])
+                probe_stream=probe_stream,
+                type=stream_type,
+                time_base=time_base,
+                start_pts=int(probe_stream.raw_data["start_pts"]),
+                duration_ts=duration_ts
             )
             stream_data_by_id[specifier.stream_id] = stream_data
             logger.info(f"Input file {file}, stream {specifier.stream_id} (in graph as {specifier.graph_id}): {stream_data}")
@@ -164,7 +172,6 @@ class InputFileJNode(JGraphNode):
         
         for specifier in self.streams:
             stream_data = stream_data_by_id[specifier.stream_id]
-            probe_stream = probe_streams[specifier.stream_id]
             
             final_start_pts = round(latest_start_time / stream_data.time_base)
             final_end_pts = round(earliest_end_time / stream_data.time_base)
@@ -205,11 +212,12 @@ class InputFileJNode(JGraphNode):
             
             duration_us = round((final_duration_ts * stream_data.time_base) * 1_000_000)
             if stream_data.type == FStreamType.VIDEO:
+                assert isinstance(stream_data.probe_stream, FFProbeVideoStream)
                 metadata = FVideoStreamMetadata(
                     metadata_fid,
                     duration_us=duration_us,
-                    width=int(probe_stream["width"]),
-                    height=int(probe_stream["height"])
+                    width=stream_data.probe_stream.width,
+                    height=stream_data.probe_stream.height,
                 )
             elif stream_data.type == FStreamType.AUDIO:
                 metadata = FAudioStreamMetadata(
@@ -326,7 +334,7 @@ class OutputFileJNode(JGraphNode):
             if output_stream.framerate is not None:
                 if not isinstance(metadata, FVideoStreamMetadata):
                     raise ValueError(f"Cannot set output framerate for non-video stream {output_stream.graph_id}")
-                out_args["r"] = output_stream.framerate#
+                out_args["r"] = output_stream.framerate
             
             if output_stream.crf is not None:
                 if not isinstance(metadata, FVideoStreamMetadata):
diff --git a/common_py/src/videoag_common/objects/medium_metadata.py b/common_py/src/videoag_common/objects/medium_metadata.py
index 947bcbd..af0318c 100644
--- a/common_py/src/videoag_common/objects/medium_metadata.py
+++ b/common_py/src/videoag_common/objects/medium_metadata.py
@@ -1,3 +1,4 @@
+import math
 
 from videoag_common.miscellaneous import *
 from videoag_common.database import *
@@ -98,14 +99,14 @@ class PlainVideoMediumMetadata(MediumMetadata, SingleVideoMetadata, SingleAudioM
             horizontal_resolution=video.width,
             video_frame_rate_numerator=video.frame_rate_numerator,
             video_frame_rate_denominator=video.frame_rate_denominator,
-            duration_sec=video.duration_sec,
+            duration_sec=math.ceil(video.duration_ns / math.pow(10, 9)),
         )
         
         if len(ffprobe.audio_streams) > 0:
             audio = ffprobe.audio_streams[0]
             medium.audio_sample_rate = audio.sample_rate
             medium.audio_channel_count = audio.channel_count
-            medium.duration_sec = max(video.duration_sec, audio.duration_sec)
+            medium.duration_sec = math.ceil(max(video.duration_ns, audio.duration_ns) / math.pow(10, 9))
         
         return medium
     
@@ -134,7 +135,7 @@ class PlainAudioMediumMetadata(MediumMetadata, SingleAudioMetadata, DurationMeta
         return PlainAudioMediumMetadata(
             audio_sample_rate=audio.sample_rate,
             audio_channel_count=audio.channel_count,
-            duration_sec=audio.duration_sec,
+            duration_sec=math.ceil(audio.duration_ns / math.pow(10, 9)),
         )
     
     def can_include_in_player(self):
-- 
GitLab