diff --git a/common_py/src/videoag_common/media_process/jnode/file_jnode.py b/common_py/src/videoag_common/media_process/jnode/file_jnode.py
index 4771a2e918681329aa80cfe3a73246f0536d51ce..7299cb1b5023221f121bae81827914278d3e1168 100644
--- a/common_py/src/videoag_common/media_process/jnode/file_jnode.py
+++ b/common_py/src/videoag_common/media_process/jnode/file_jnode.py
@@ -16,8 +16,27 @@ if TYPE_CHECKING:  # pragma: no cover
     from videoag_common.objects import Lecture, MediumMetadata, MediumFile
 
 
+class StreamIdContext(JsonSerializableEnum):
+    ALL = "all"
+    VIDEO = "video"
+    AUDIO = "audio"
+
+
+def is_stream_in_context(context: StreamIdContext, stream: FFProbeStream):
+    match context:
+        case StreamIdContext.ALL:
+            return True
+        case StreamIdContext.VIDEO:
+            return isinstance(stream, FFProbeVideoStream)
+        case StreamIdContext.AUDIO:
+            return isinstance(stream, FFProbeAudioStream)
+        case _:
+            raise AssertionError(f"Unknown id context: {context}")
+
+
 class StreamSpecifier(JsonDataClass):
     stream_id: int
+    stream_id_context: StreamIdContext = StreamIdContext.ALL
     graph_id: str
 
 
@@ -27,14 +46,10 @@ class InputFileJNode(JGraphNode):
     
     def __post_init__(self):
         self._output_ids: set[str] = set()
-        stream_ids: set[int] = set()
         for specifier in self.streams:
             if specifier.graph_id in self._output_ids:
                 raise JsonSerializableInitException(f"Duplicate graph id {specifier.graph_id}")
-            if specifier.stream_id in stream_ids:
-                raise JsonSerializableInitException(f"Duplicate stream id {specifier.stream_id}")
             self._output_ids.add(specifier.graph_id)
-            stream_ids.add(specifier.stream_id)
     
     @classmethod
     def get_type(cls) -> str:
@@ -80,6 +95,16 @@ class InputFileJNode(JGraphNode):
         probe.load_data(file)
         probe_streams = probe.all_streams
         
+        context_stream_mappings: dict[StreamIdContext, list[int]] = {}
+        for stream_context in StreamIdContext:
+            current_context_mapping = []
+            context_stream_mappings[stream_context] = current_context_mapping
+            
+            for i, stream in zip(range(len(probe_streams)), probe_streams):
+                if is_stream_in_context(stream_context, stream):
+                    current_context_mapping.append(i)
+                    continue
+        
         # Some nodes in the pipeline handle video and audio separately, and they need to use "setpts=PTS-STARTPTS"
         # to reset a streams start pts to 0 (e.g. for the overlay filter). However, if such nodes have to deal with
         # audio and video (like the slide node), the video and audio needs to be kept in sync. Actions like resetting
@@ -98,6 +123,9 @@ class InputFileJNode(JGraphNode):
         
         @dataclass
         class StreamData:
+            stream_id: int
+            # list as a stream might be selected multiple times in different contexts
+            graph_ids: list[str]
             probe_stream: FFProbeStream
             type: FStreamType
             time_base: Fraction
@@ -125,11 +153,20 @@ class InputFileJNode(JGraphNode):
         stream_data_by_id: dict[int, StreamData] = {}
         
         for specifier in self.streams:
-            if specifier.stream_id >= len(probe_streams):
-                raise ValueError(f"Invalid Stream ID {specifier.stream_id} for file {file} which has only"
-                                 f" {len(probe_streams)} streams")
+            context_stream_ids = context_stream_mappings[specifier.stream_id_context]
+            if specifier.stream_id >= len(context_stream_ids):
+                raise ValueError(f"Invalid Stream {specifier.stream_id} in context {specifier.stream_id_context} for"
+                                 f" file {file}. That context only has {len(context_stream_ids)} streams")
+            stream_id = context_stream_ids[specifier.stream_id]
+            probe_stream = probe_streams[stream_id]
+            
+            logger.info(f"Stream {specifier.stream_id} in context {specifier.stream_id_context} maps to global stream ID {stream_id}")
             
-            probe_stream = probe_streams[specifier.stream_id]
+            if stream_id in stream_data_by_id:
+                stream_data = stream_data_by_id[stream_id]
+                logger.warning(f"Stream {stream_id} was mapped again to graph id '{specifier.graph_id}' (already mapped to '{"', '".join(stream_data.graph_ids)}')")
+                stream_data.graph_ids.append(specifier.graph_id)
+                continue
             
             if isinstance(probe_stream, FFProbeVideoStream):
                 stream_type = FStreamType.VIDEO
@@ -137,22 +174,25 @@ class InputFileJNode(JGraphNode):
                 stream_type = FStreamType.AUDIO
             else:
                 raise ValueError(f"Unknown stream type {type(probe_stream)} returned by ffprobe for file {file}"
-                                 f" for stream {specifier.stream_id}")
+                                 f" for stream {stream_id}")
             
             time_base = parse_fraction(probe_stream.raw_data["time_base"])
             if "duration_ts" in probe_stream.raw_data:
                 duration_ts = int(probe_stream.raw_data["duration_ts"])
             else:
                 duration_ts = Fraction(probe_stream.duration_ns, int(math.pow(10, 9))) / time_base
+            
             stream_data = StreamData(
+                stream_id=stream_id,
+                graph_ids=[specifier.graph_id],
                 probe_stream=probe_stream,
                 type=stream_type,
                 time_base=time_base,
                 start_pts=int(probe_stream.raw_data["start_pts"]),
                 duration_ts=duration_ts
             )
-            stream_data_by_id[specifier.stream_id] = stream_data
-            logger.info(f"Input file {file}, stream {specifier.stream_id} (in graph as {specifier.graph_id}): {stream_data}")
+            stream_data_by_id[stream_id] = stream_data
+            logger.info(f"Input file {file}, stream {stream_id} (in graph as {specifier.graph_id}): {stream_data}")
             
             if latest_start_time is None:
                 latest_start_time = stream_data.start_time_sec
@@ -169,8 +209,7 @@ class InputFileJNode(JGraphNode):
         
         stream_fids: dict[int, FID] = {}
         
-        for specifier in self.streams:
-            stream_data = stream_data_by_id[specifier.stream_id]
+        for stream_id, stream_data in stream_data_by_id.items():
             
             final_start_pts = round(latest_start_time / stream_data.time_base)
             final_end_pts = round(earliest_end_time / stream_data.time_base)
@@ -180,10 +219,11 @@ class InputFileJNode(JGraphNode):
                           or final_duration_ts != stream_data.duration_ts)
             needs_start_pts_reset = needs_trim or final_start_pts != 0
             
-            metadata_fid = FID(stream_data.type, name_hint=f"jid_{specifier.graph_id}")
+            jid_name_hint = f"jid_{"_".join(stream_data.graph_ids)}"
+            metadata_fid = FID(stream_data.type, name_hint=jid_name_hint)
             
             if needs_start_pts_reset:
-                input_fid = FID(stream_data.type, name_hint=f"jid_{specifier.graph_id}_pre_aligned")
+                input_fid = FID(stream_data.type, name_hint=f"{jid_name_hint}_pre_aligned")
             else:
                 input_fid = metadata_fid
             
@@ -193,7 +233,7 @@ class InputFileJNode(JGraphNode):
                 chain = []
                 
                 if needs_trim:
-                    logger.info(f"Trimming stream {specifier.stream_id} to"
+                    logger.info(f"Trimming stream {stream_id} to"
                                 f" start_pts={final_start_pts} ({float(final_start_pts * stream_data.time_base):.8f}s),"
                                 f" end_pts={final_end_pts} ({float(final_end_pts * stream_data.time_base):.8f}s)"
                                 f" (duration {final_duration_ts} or {float(final_duration_ts * stream_data.time_base):.8f}s)")
@@ -202,7 +242,7 @@ class InputFileJNode(JGraphNode):
                         "end_pts": final_end_pts
                     }))
                 
-                logger.info(f"Resetting STARTPTS for stream {specifier.stream_id} to 0")
+                logger.info(f"Resetting STARTPTS for stream {stream_id} to 0")
                 chain.append(ChainFFilter(f"{filter_prefix}setpts", {
                     "expr": "PTS-STARTPTS"
                 }))
@@ -226,8 +266,9 @@ class InputFileJNode(JGraphNode):
             else:
                 raise AssertionError("Unknown type")
             
-            stream_fids[specifier.stream_id] = input_fid
-            context.add_stream_metadata(specifier.graph_id, metadata)
+            stream_fids[stream_id] = input_fid
+            for graph_id in stream_data.graph_ids:
+                context.add_stream_metadata(graph_id, metadata)
         
         context.add_input_file(file, stream_fids)