From 623ffb33bea81fcdae38e7b0c2929ac9bc7da503 Mon Sep 17 00:00:00 2001 From: Mihail Feraru Date: Sun, 11 Jun 2023 19:49:57 +0300 Subject: [PATCH] Add support for multi-line SQL and commands. --- CHANGES.txt | 2 + crate/crash/command.py | 73 +++++++++------------- setup.py | 1 + tests/test_commands.py | 127 ++++++++++++++++++++++++++++++++++++-- tests/test_integration.py | 6 +- 5 files changed, 158 insertions(+), 51 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 5b677183..9e698f09 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -6,6 +6,8 @@ Unreleased ========== - Fix inconsistent spacing around printed runtime. Thank you, @hammerhead. +- Add support for multi-line input of commands and SQL statements for both + copy-pasting inside the crash shell and input pipes into crash. 2023/02/16 0.29.0 ================= diff --git a/crate/crash/command.py b/crate/crash/command.py index 06381c17..7033fb61 100644 --- a/crate/crash/command.py +++ b/crate/crash/command.py @@ -33,6 +33,7 @@ from getpass import getpass from operator import itemgetter +import sqlparse import urllib3 from platformdirs import user_config_dir, user_data_dir from urllib3.exceptions import LocationParseError @@ -176,31 +177,6 @@ def inner_fn(self, *args): return inner_fn -def _parse_statements(lines): - """Return a generator of statements - - Args: A list of strings that can contain one or more statements. - Statements are separated using ';' at the end of a line - Everything after the last ';' will be treated as the last statement. - - >>> list(_parse_statements(['select * from ', 't1;', 'select name'])) - ['select * from\\nt1', 'select name'] - - >>> list(_parse_statements(['select * from t1;', ' '])) - ['select * from t1'] - """ - lines = (l.strip() for l in lines if l) - lines = (l for l in lines if l and not l.startswith('--')) - parts = [] - for line in lines: - parts.append(line.rstrip(';')) - if line.endswith(';'): - yield '\n'.join(parts) - parts[:] = [] - if parts: - yield '\n'.join(parts) - - class CrateShell: def __init__(self, @@ -274,19 +250,28 @@ def pprint(self, rows, cols): self.get_num_columns()) self.output_writer.write(result) - def process_iterable(self, stdin): - any_statement = False - for statement in _parse_statements(stdin): - self._exec_and_print(statement) - any_statement = True - return any_statement + def process_iterable(self, iterable): + self._process_lines([line for text in iterable for line in text.split('\n')]) def process(self, text): - if text.startswith('\\'): - self._try_exec_cmd(text.lstrip('\\')) - else: - for statement in _parse_statements([text]): - self._exec_and_print(statement) + self._process_lines(text.split('\n')) + + def _process_lines(self, lines): + sql_lines = [] + for line in lines: + line = line.strip() + if line.startswith('\\'): + self._process_sql('\n'.join(sql_lines)) + self._try_exec_cmd(line.lstrip('\\')) + sql_lines = [] + else: + sql_lines.append(line) + self._process_sql('\n'.join(sql_lines)) + + def _process_sql(self, text): + sql = sqlparse.format(text, strip_comments=False) + for statement in sqlparse.split(sql): + self._exec_and_print(statement) def exit(self): self.close() @@ -498,14 +483,15 @@ def stmt_type(statement): return re.findall(r'[\w]+', statement)[0].upper() -def get_stdin(): +def get_lines_from_stdin(): """ - Get data from stdin, if any + Get data line by line from stdin, if any """ - if not sys.stdin.isatty(): - for line in sys.stdin: - yield line - return + if sys.stdin.isatty(): + return + + for line in sys.stdin: + yield line def host_and_port(host_or_port): @@ -622,7 +608,8 @@ def save_and_exit(): cmd.process(args.command) save_and_exit() - if cmd.process_iterable(get_stdin()): + if not sys.stdin.isatty(): + cmd.process_iterable(get_lines_from_stdin()) save_and_exit() from .repl import loop diff --git a/setup.py b/setup.py index 59622d45..28ddb856 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ 'platformdirs<4', 'prompt-toolkit>=3.0,<4', 'tabulate>=0.9,<0.10', + 'sqlparse>=0.4.4,<0.5.0' ] diff --git a/tests/test_commands.py b/tests/test_commands.py index 3c998c41..55804eba 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -229,16 +229,24 @@ def test_sql_comments(self): -- Another SELECT statement. SELECT 2; -- Yet another SELECT statement with an inline comment. --- Other than the regular comments, it gets passed through to the database server. +-- Comments get passed through to the database server. SELECT /* this is a comment */ 3; +SELECT /* this is a multi-line +comment */ 4; """ cmd = CrateShell() cmd._exec_and_print = MagicMock() cmd.process_iterable(sql.splitlines()) - cmd._exec_and_print.assert_has_calls([ - call("SELECT 1"), - call("SELECT 2"), - call("SELECT /* this is a comment */ 3"), + self.assertListEqual(cmd._exec_and_print.mock_calls, [ + call("-- Just a dummy SELECT statement.\nSELECT 1;"), + call("-- Another SELECT statement.\nSELECT 2;"), + call('\n'.join([ + "-- Yet another SELECT statement with an inline comment.", + "-- Comments get passed through to the database server.", + "SELECT /* this is a comment */ 3;" + ]) + ), + call('SELECT /* this is a multi-line\ncomment */ 4;') ]) def test_js_comments(self): @@ -262,3 +270,112 @@ def test_js_comments(self): cmd.process(sql) self.assertEqual(1, cmd._exec_and_print.call_count) self.assertIn("CREATE FUNCTION", cmd._exec_and_print.mock_calls[0].args[0]) + + +class MultipleStatementsTest(TestCase): + + def test_single_line_multiple_sql_statements(self): + cmd = CrateShell() + cmd._exec_and_print = MagicMock() + cmd.process("SELECT 1; SELECT 2; SELECT 3;") + self.assertListEqual(cmd._exec_and_print.mock_calls, [ + call("SELECT 1;"), + call("SELECT 2;"), + call("SELECT 3;"), + ]) + + def test_multiple_lines_multiple_sql_statements(self): + cmd = CrateShell() + cmd._exec_and_print = MagicMock() + cmd.process("SELECT 1;\nSELECT 2; SELECT 3;\nSELECT\n4;") + self.assertListEqual(cmd._exec_and_print.mock_calls, [ + call("SELECT 1;"), + call("SELECT 2;"), + call("SELECT 3;"), + call("SELECT\n4;"), + ]) + + def test_single_sql_statement_multiple_lines(self): + """When processing single SQL statements, new lines are preserved.""" + + cmd = CrateShell() + cmd._exec_and_print = MagicMock() + cmd.process("\nSELECT\n1\nWHERE\n2\n=\n3\n;\n") + self.assertListEqual(cmd._exec_and_print.mock_calls, [ + call("SELECT\n1\nWHERE\n2\n=\n3\n;"), + ]) + + def test_multiple_commands_no_sql(self): + cmd = CrateShell() + cmd._try_exec_cmd = MagicMock() + cmd._exec_and_print = MagicMock() + cmd.process("\\?\n\\connect 127.0.0.1") + cmd._try_exec_cmd.assert_has_calls([ + call("?"), + call("connect 127.0.0.1") + ]) + cmd._exec_and_print.assert_not_called() + + def test_commands_and_multiple_sql_statements_interleaved(self): + """Combine all test cases above to be sure everything integrates well.""" + + cmd = CrateShell() + mock_manager = MagicMock() + + cmd._try_exec_cmd = mock_manager.cmd + cmd._exec_and_print = mock_manager.sql + + cmd.process(""" + \\? + SELECT 1 + WHERE 2 = 3; SELECT 4; + \\connect 127.0.0.1 + SELECT 5 + WHERE 6 = 7; + \\check + """) + + self.assertListEqual(mock_manager.mock_calls, [ + call.cmd("?"), + call.sql('SELECT 1\nWHERE 2 = 3;'), + call.sql('SELECT 4;'), + call.cmd("connect 127.0.0.1"), + call.sql('SELECT 5\nWHERE 6 = 7;'), + call.cmd("check"), + ]) + + def test_comments_along_multiple_statements(self): + """Test multiple types of comments along multi-statement input.""" + + cmd = CrateShell() + cmd._exec_and_print = MagicMock() + + cmd.process(""" +-- Multiple statements and comments on same line + +SELECT /* inner comment */ 1; /* this is a single-line comment */ SELECT /* inner comment */ 2; + +-- Multiple statements on multiple lines with multi-line comments between them + +SELECT /* inner comment */ 3; /* this is a +multi-line comment */ SELECT /* inner comment */ 4; + +-- Multiple statements on multiple lines with multi-line comments between and inside them + +SELECT /* inner multi-line +comment */ 5 /* this is a multi-line +comment before statement end */; /* this is another multi-line +comment */ SELECT /* inner multi-line +comment */ 6; + """) + + self.assertListEqual(cmd._exec_and_print.mock_calls, [ + call('-- Multiple statements and comments on same line\n\nSELECT /* inner comment */ 1;'), + call('/* this is a single-line comment */ SELECT /* inner comment */ 2;'), + + call('-- Multiple statements on multiple lines with multi-line comments between them\n\nSELECT /* inner comment */ 3;'), + call('/* this is a\nmulti-line comment */ SELECT /* inner comment */ 4;'), + + call('-- Multiple statements on multiple lines with multi-line comments between and inside them\n\nSELECT /* inner multi-line\ncomment */ 5 /* this is a multi-line\ncomment before statement end */;'), + call('/* this is another multi-line\ncomment */ SELECT /* inner multi-line\ncomment */ 6;') + ]) \ No newline at end of file diff --git a/tests/test_integration.py b/tests/test_integration.py index a5987c52..dae753f4 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,8 +13,8 @@ from crate.crash.command import ( CrateShell, _create_shell, + get_lines_from_stdin, get_parser, - get_stdin, host_and_port, main, noargs_command, @@ -315,7 +315,7 @@ def test_multiline_stdin(self): Newlines must be replaced with whitespaces """ - stmt = ''.join(list(get_stdin())).replace('\n', ' ') + stmt = ''.join(list(get_lines_from_stdin())).replace('\n', ' ') expected = ("create table test( d string ) " "clustered into 2 shards " "with (number_of_replicas=0)") @@ -334,7 +334,7 @@ def test_multiline_stdin_delimiter(self): Newlines must be replaced with whitespaces """ - stmt = ''.join(list(get_stdin())).replace('\n', ' ') + stmt = ''.join(list(get_lines_from_stdin())).replace('\n', ' ') expected = ("create table test( d string ) " "clustered into 2 shards " "with (number_of_replicas=0);")