Skip to content

Commit

Permalink
Merge pull request #84 from kayak/remove_dialect_magic
Browse files Browse the repository at this point in the history
Removed function magic for sql dialects
  • Loading branch information
twheys authored Dec 12, 2017
2 parents aac1099 + 9178398 commit b90ac7f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 130 deletions.
45 changes: 4 additions & 41 deletions pypika/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
"""
from pypika.enums import (
SqlTypes,
Dialects,
)
from pypika.terms import (
AggregateFunction,
Function,
Star,
AggregateFunction,
ValueWrapper,
)
from pypika.utils import builder

Expand Down Expand Up @@ -191,50 +189,15 @@ class SplitPart(Function):
def __init__(self, term, delimiter, index, alias=None):
super(SplitPart, self).__init__('SPLIT_PART', term, delimiter, index, alias=alias)

def get_name_for_dialect(self, dialect=None):
return {
Dialects.MYSQL: 'SUBSTRING_INDEX',
Dialects.POSTGRESQL: 'SPLIT_PART',
Dialects.REDSHIFT: 'SPLIT_PART',
Dialects.VERTICA: 'SPLIT_PART',
Dialects.ORACLE: 'REGEXP_SUBSTR',
}.get(dialect, None)

def get_args_for_dialect(self, dialect=None):
term, delimiter, index = self.args

return {
Dialects.MYSQL: (term, delimiter, index),
Dialects.POSTGRESQL: (term, delimiter, index),
Dialects.REDSHIFT: (term, delimiter, index),
Dialects.VERTICA: (term, delimiter, index),
Dialects.ORACLE: (term, ValueWrapper('[^{}]+'.format(delimiter.value)), 1, index)
}.get(dialect, None)

class RegexpMatches(Function):
def __init__(self, term, pattern, modifiers, alias=None):
super(RegexpMatches, self).__init__('REGEXP_MATCHES', term, pattern, modifiers, alias=alias)

class RegexpLike(Function):
def __init__(self, term, pattern, modifiers, alias=None):
super(RegexpLike, self).__init__('REGEXP_LIKE', term, pattern, modifiers, alias=alias)

def get_name_for_dialect(self, dialect=None):
return {
Dialects.POSTGRESQL: 'REGEXP_MATCHES',
Dialects.REDSHIFT: 'REGEXP_MATCHES',
Dialects.VERTICA: 'REGEXP_LIKE',
Dialects.ORACLE: 'REGEXP_LIKE',
}.get(dialect, self.name)

def get_args_for_dialect(self, dialect=None):
term, pattern, modifiers = self.args

return {
Dialects.POSTGRESQL: (term, pattern, modifiers),
Dialects.REDSHIFT: (term, pattern, modifiers),
Dialects.VERTICA: (term, pattern, modifiers),
Dialects.ORACLE: (term, pattern, modifiers)
}.get(dialect, None)


# Date Functions
class Now(Function):
def __init__(self, alias=None):
Expand Down
36 changes: 3 additions & 33 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)
from pypika.utils import (
CaseException,
DialectNotSupported,
alias_sql,
builder,
ignoredeepcopy,
Expand Down Expand Up @@ -655,48 +654,19 @@ def get_special_params_sql(self, **kwargs):
def get_function_sql(self, **kwargs):
special_params_sql = self.get_special_params_sql(**kwargs)

dialect = kwargs.get('dialect', None)
dialect_name = self.get_name_for_dialect(dialect=dialect)
dialect_args = self.get_args_for_dialect(dialect=dialect)

if dialect_name is None or dialect_args is None:
raise DialectNotSupported('The function {} has no support for {} dialect'.format(self.name, dialect))

return '{name}({args}{special})'.format(
name=dialect_name,
name=self.name,
args=','.join(p.get_sql(with_alias=False, **kwargs)
if hasattr(p, 'get_sql')
else str(p)
for p in dialect_args),
for p in self.args),
special=(' ' + special_params_sql) if special_params_sql else '',
)

def get_name_for_dialect(self, dialect=None):
"""
This function will transform the original function name into the equivalent for different dialects.
In practice this method should be overriden on subclasses whenever different dialects support is
required. Otherwise the original name will be used.
:param dialect: one of the options in the Dialects enum.
:return: the function name that should be used by the get_function_sql method when serializing.
"""
return self.name

def get_args_for_dialect(self, dialect=None):
"""
This function will transform the original function args into the equivalent for different dialects.
In practice this method should be overriden on subclasses whenever different dialects support is
required. Otherwise the original arguments will be used.
:param dialect: one of the options in the Dialects enum.
:return: the function args that should be used by the get_function_sql method when serializing.
"""
return self.args

def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, **kwargs):
# FIXME escape

function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, **kwargs)
function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char)

if not with_alias or self.alias is None:
return function_sql
Expand Down
68 changes: 12 additions & 56 deletions pypika/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,20 @@
import unittest

from pypika import (
Query as Q,
Table as T,
Field as F,
functions as fn,
CaseException,
Case,
Interval,
CaseException,
DatePart,
MySQLQuery,
Field as F,
Interval,
Query as Q,
Table as T,
VerticaQuery,
PostgreSQLQuery,
RedshiftQuery,
OracleQuery,
functions as fn,
)
from pypika.enums import (
Dialects,
SqlTypes,
)
from pypika.enums import (SqlTypes,
Dialects)
from pypika.utils import DialectNotSupported

__author__ = "Timothy Heys"
__email__ = "theys@kayak.com"
Expand Down Expand Up @@ -369,61 +366,20 @@ def test__length__field(self):
class SplitPartFunctionTests(unittest.TestCase):
t = T('abc')

def test__split_part__field_with_vertica_dialect(self):
def test__split_part(self):
q = VerticaQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SPLIT_PART(\"foo\",\'|\',3) FROM \"abc\"", str(q))

def test__split_part__field_with_mysql_dialect(self):
q = MySQLQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SUBSTRING_INDEX(`foo`,\'|\',3) FROM `abc`", str(q))

def test__split_part__field_with_postgresql_dialect(self):
q = PostgreSQLQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SPLIT_PART(\"foo\",\'|\',3) FROM \"abc\"", str(q))

def test__split_part__field_with_redshift_dialect(self):
q = RedshiftQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SPLIT_PART(\"foo\",\'|\',3) FROM \"abc\"", str(q))

def test__split_part__field_with_oracle_dialect(self):
q = OracleQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT REGEXP_SUBSTR(\"foo\",\'[^|]+\',1,3) FROM \"abc\"", str(q))


class RegexpLikeFunctionTests(unittest.TestCase):
t = T('abc')

def test__regexp_like__field_with_vertica_dialect(self):
def test__regexp_like(self):
q = VerticaQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_LIKE(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))

def test__regexp_like__field_with_mysql_dialect(self):
q = MySQLQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

with self.assertRaises(DialectNotSupported):
str(q)

def test__regexp_like__field_with_postgresql_dialect(self):
q = PostgreSQLQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_MATCHES(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))

def test__regexp_like__field_with_redshift_dialect(self):
q = RedshiftQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_MATCHES(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))

def test__regexp_like__field_with_oracle_dialect(self):
q = OracleQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_LIKE(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))


class CastTests(unittest.TestCase):
t = T('abc')
Expand Down

0 comments on commit b90ac7f

Please sign in to comment.