Skip to content

Commit 398e27d

Browse files
authored
Add failure tests for stealing subgraphs. Minor fix in pipeline validation. (NVIDIA#5518)
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
1 parent 9420fb8 commit 398e27d

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

dali/python/nvidia/dali/pipeline.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,9 @@ def __exit__(self, type, value, traceback):
722722
def _require_unique_names(self):
723723
ops_by_name = {}
724724
for op in self._ops:
725-
ops = ops_by_name.get(op.name, [])
725+
ops = ops_by_name.get(op.name, None)
726+
if ops is None:
727+
ops = ops_by_name[op.name] = []
726728
ops.append(op)
727729
duplicate = {}
728730
foreign = False

dali/test/python/checkpointing/test_dali_checkpointing.py

+19
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from nvidia.dali.auto_aug import trivial_augment as ta
3737
from reader.test_numpy import is_gds_supported
3838
from nose.plugins.attrib import attr
39+
from nose_utils import assert_raises
40+
3941

4042
reader_signed_off = create_sign_off_decorator()
4143
random_signed_off = create_sign_off_decorator()
@@ -1185,6 +1187,23 @@ def pipeline():
11851187
compare_pipelines(pipe2, pipe3, batch_size, 5)
11861188

11871189

1190+
def test_unsupported_dangling_subgraph():
1191+
es = fn.external_source("asdf")
1192+
1193+
@pipeline_def(batch_size=1, num_threads=1, device_id=None, enable_checkpointing=True)
1194+
def pipe(arg):
1195+
return arg + 0
1196+
1197+
p = pipe(es)
1198+
1199+
with assert_raises(
1200+
RuntimeError,
1201+
glob="The pipeline does not support checkpointing*"
1202+
"because it contains operator*outside the pipeline*",
1203+
):
1204+
p.build()
1205+
1206+
11881207
unsupported_readers = [
11891208
"experimental.readers.fits",
11901209
]

dali/test/python/test_pipeline.py

+18
Original file line numberDiff line numberDiff line change
@@ -2204,3 +2204,21 @@ def test_regression_without_current_pipeline2():
22042204
data = fn.external_source(source=[1, 2, 3], batch=False, cycle=True)
22052205
pipe.set_outputs(data.gpu())
22062206
pipe.build()
2207+
2208+
2209+
def test_subgraph_stealing():
2210+
p1 = Pipeline(batch_size=1, device_id=None, num_threads=1)
2211+
p2 = Pipeline(batch_size=1, device_id=None, num_threads=1)
2212+
with p1:
2213+
es1 = fn.external_source(source=[1, 2, 3], batch=False)
2214+
x = es1 + 1
2215+
p1.set_outputs(x)
2216+
with p2:
2217+
es2 = fn.external_source(source=[1, 2, 3], batch=False)
2218+
p2.set_outputs(x + es2)
2219+
p1.build()
2220+
with assert_raises(
2221+
RuntimeError,
2222+
glob="The pipeline is invalid because it contains operators with non-unique names",
2223+
):
2224+
p2.build()

0 commit comments

Comments
 (0)