diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index 073172efb..4ab94faf9 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ b/docs/source/using_doctr/using_model_export.rst @@ -119,10 +119,10 @@ It defines a common format for representing models, including the network struct from doctr.models import vitstr_small from doctr.models.utils import export_model_to_onnx - batch_size = 16 + batch_size = 1 input_shape = (3, 32, 128) model = vitstr_small(pretrained=True, exportable=True) - dummy_input = torch.rand((batch_size, input_shape), dtype=torch.float32) + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) model_path = export_model_to_onnx( model, model_name="vitstr.onnx", @@ -137,10 +137,10 @@ It defines a common format for representing models, including the network struct from doctr.models import vitstr_small from doctr.models.utils import export_model_to_onnx - batch_size = 16 + batch_size = 1 input_shape = (32, 128, 3) model = vitstr_small(pretrained=True, exportable=True) - dummy_input = [tf.TensorSpec([batch_size, input_shape], tf.float32, name="input")] + dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")] model_path, output = export_model_to_onnx( model, model_name="vitstr.onnx",