diff --git a/compiler/dialects.py b/compiler/dialects.py index 81ff306..28c8d5c 100755 --- a/compiler/dialects.py +++ b/compiler/dialects.py @@ -415,6 +415,7 @@ def BuiltInFunctions(self): 'Least': 'LEAST(%s)', 'Greatest': 'GREATEST(%s)', 'ToString': 'CAST(%s AS TEXT)', + 'ToFloat64': 'CAST(%s AS DOUBLE)', 'DateAddDay': "DATE({0}, {1} || ' days')", 'DateDiffDay': "CAST(JULIANDAY({0}) - JULIANDAY({1}) AS INT64)", 'CurrentTimestamp': 'GET_CURRENT_TIMESTAMP()', diff --git a/compiler/functors.py b/compiler/functors.py index f1e2c02..e39d0ef 100755 --- a/compiler/functors.py +++ b/compiler/functors.py @@ -622,8 +622,9 @@ def CountNils(self, node): # Do not walk into: # predicate value literals, # combine expressions because they will be trivially - # null. - taboo=['the_predicate', 'combine']) + # null, + # lists of satelites. + taboo=['the_predicate', 'combine', 'satellites']) rules_per_predicate[p] = rules_per_predicate.get(p,0) + ( c.nil_count == 0) is_nothing = set() diff --git a/compiler/universe.py b/compiler/universe.py index c5a04d8..3abc201 100755 --- a/compiler/universe.py +++ b/compiler/universe.py @@ -635,11 +635,33 @@ def CheckDistinctConsistency(self): 'Either all rules of a predicate must be distinct denoted ' 'or none. Predicate {warning}{p}{end} violates it.', dict(p=p)), r['full_text']) - + + def InscribeOrbits(self, rules, depth_map): + """Writes satellites from annotation to a field in the body.""" + # Satellites are written to master and master is written to + # satellites. You are responsible for those whom you tamed. + master = {} + for p, args in depth_map.items(): + for s in args.get('satellites', []): + master[s['predicate_name']] = p + for r in rules: + p = r['head']['predicate_name'] + if p in depth_map: + if 'satellites' not in depth_map[p]: + continue + if 'body' not in r: + r['body'] = {'conjunction': {'conjunct': []}} + r['body']['conjunction']['satellites'] = depth_map[p]['satellites'] + if p in master: + if 'body' not in r: + r['body'] = {'conjunction': {'conjunct': []}} + r['body']['conjunction']['satellites'] = [{'predicate_name': master[p]}] + def UnfoldRecursion(self, rules): annotations = Annotations(rules, {}) - f = functors.Functors(rules) depth_map = annotations.annotations.get('@Recursive', {}) + self.InscribeOrbits(rules, depth_map) + f = functors.Functors(rules) # Annotations are not ready at this point. # if (self.execution.annotations.Engine() == 'duckdb'): # for p in depth_map: diff --git a/integration_tests/duckdb_smoothed_winmove_test.l b/integration_tests/duckdb_smoothed_winmove_test.l new file mode 100644 index 0000000..8243c5d --- /dev/null +++ b/integration_tests/duckdb_smoothed_winmove_test.l @@ -0,0 +1,35 @@ +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +@Engine("duckdb"); + +Move("a", "b"); +Move("b", "a"); +Move("b", "c"); +Move("c", "d"); + +Node(x) distinct :- Move(a, b), x in [a, b]; + +N() += N(); +N() += 1; + +@Recursive(Win, 101, satellites: [SmoothedWin, N]); +Win(x) :- Move(x, y), ~Win(y); + +SmoothedWin(x) += SmoothedWin(x) * (N() - 1) / N(); +SmoothedWin(x) += ToFloat64(!~ Win(x)) / N() :- Node(x); + +@OrderBy(Test, "col0"); +Test(x, SmoothedWin(x)); diff --git a/integration_tests/duckdb_smoothed_winmove_test.txt b/integration_tests/duckdb_smoothed_winmove_test.txt new file mode 100644 index 0000000..1dda546 --- /dev/null +++ b/integration_tests/duckdb_smoothed_winmove_test.txt @@ -0,0 +1,8 @@ ++------+--------------------+ +| col0 | col1 | ++------+--------------------+ +| a | 0.5049504950495048 | +| b | 0.5049504950495048 | +| c | 1.0 | +| d | 0.0 | ++------+--------------------+ \ No newline at end of file diff --git a/integration_tests/run_tests.py b/integration_tests/run_tests.py index 31ae3fa..9c64ad8 100755 --- a/integration_tests/run_tests.py +++ b/integration_tests/run_tests.py @@ -110,6 +110,8 @@ def RunAll(test_presto=False, test_trino=False): RunTest("sqlite_element_test") RunTest("sqlite_functor_over_constant_test") + RunTest("duckdb_smoothed_winmove_test", + use_concertina=True) RunTest("duckdb_iteration_closure_test", use_concertina=True) RunTest("duckdb_stop_test",