"""Source code for pypika.dialects: dialect-specific query builders and query classes."""

from copy import copy

from pypika.enums import Dialects
from pypika.queries import (
    Query,
    QueryBuilder,
)
from pypika.terms import (
    ArithmeticExpression,
    Field,
    Function,
    Star,
    ValueWrapper,
)
from pypika.utils import (
    QueryException,
    builder,
)


class SnowFlakeQueryBuilder(QueryBuilder):
    """
    Query builder for the Snowflake dialect.

    Overrides ``QUOTE_CHAR`` with None and ``ALIAS_QUOTE_CHAR`` with a double quote.
    """

    ALIAS_QUOTE_CHAR = '"'
    QUOTE_CHAR = None

    def __init__(self):
        super().__init__(dialect=Dialects.SNOWFLAKE)


class SnowflakeQuery(Query):
    """
    Query entry point for Snowflake.

    ``_builder`` hands out a fresh SnowFlakeQueryBuilder on every call.
    """

    @classmethod
    def _builder(cls):
        snowflake_builder = SnowFlakeQueryBuilder()
        return snowflake_builder


class MySQLQueryBuilder(QueryBuilder):
    """
    Query builder for MySQL.

    Uses a backtick as ``QUOTE_CHAR``, passes ``wrap_union_queries=False`` to the
    base builder, and adds ``ON DUPLICATE KEY UPDATE`` support plus SELECT
    modifiers such as ``SQL_CALC_FOUND_ROWS``.
    """

    QUOTE_CHAR = '`'

    def __init__(self):
        super(MySQLQueryBuilder, self).__init__(dialect=Dialects.MYSQL,
                                                wrap_union_queries=False)
        # (Field, ValueWrapper) pairs rendered by _on_duplicate_key_update_sql.
        self._duplicate_updates = []
        # Raw modifier strings rendered between SELECT [DISTINCT] and the select list.
        self._modifiers = []

    def __copy__(self):
        newone = super(MySQLQueryBuilder, self).__copy__()
        # Both lists are appended to in place, so the copy needs its own
        # instances; otherwise changes to one query leak into the other.
        newone._duplicate_updates = copy(self._duplicate_updates)
        newone._modifiers = copy(self._modifiers)
        return newone

    @builder
    def on_duplicate_key_update(self, field, value):
        """
        Adds a ``field=value`` assignment to the ON DUPLICATE KEY UPDATE clause.

        :param field: A ``Field`` or a column name (wrapped in ``Field``).
        :param value: The value to assign; wrapped in ``ValueWrapper``.
        """
        field = Field(field) if not isinstance(field, Field) else field
        self._duplicate_updates.append((field, ValueWrapper(value)))

    def get_sql(self, **kwargs):
        """
        Renders the query, appending ON DUPLICATE KEY UPDATE when assignments exist.

        The clause is only appended when the base builder produced non-empty SQL.
        """
        self._set_kwargs_defaults(kwargs)
        querystring = super(MySQLQueryBuilder, self).get_sql(**kwargs)
        if querystring and self._duplicate_updates:
            querystring += self._on_duplicate_key_update_sql(**kwargs)
        return querystring

    def _on_duplicate_key_update_sql(self, **kwargs):
        # Comma-joined "field=value" pairs, in the order they were added.
        return ' ON DUPLICATE KEY UPDATE {updates}' \
            .format(updates=','.join('{field}={value}'
                                     .format(field=field.get_sql(**kwargs),
                                             value=value.get_sql(**kwargs))
                                     for field, value in self._duplicate_updates))

    @builder
    def modifier(self, value):
        """
        Adds a modifier such as SQL_CALC_FOUND_ROWS to the query.
        https://dev.mysql.com/doc/refman/5.7/en/select.html

        :param value: The modifier value e.g. SQL_CALC_FOUND_ROWS
        """
        self._modifiers.append(value)

    def _select_sql(self, **kwargs):
        """
        Overridden function to generate the SELECT part of the SQL statement,
        with the addition of any modifiers that are present.
        """
        return 'SELECT {distinct}{modifier}{select}'.format(
              distinct='DISTINCT ' if self._distinct else '',
              modifier='{} '.format(' '.join(self._modifiers)) if self._modifiers else '',
              select=','.join(term.get_sql(with_alias=True, subquery=True, **kwargs)
                              for term in self._selects),
        )


class MySQLQuery(Query):
    """
    Query entry point for MySQL.

    ``_builder`` hands out a fresh MySQLQueryBuilder on every call.
    """

    @classmethod
    def _builder(cls):
        mysql_builder = MySQLQueryBuilder()
        return mysql_builder


class VerticaQueryBuilder(QueryBuilder):
    """
    Query builder for Vertica, with optional support for a ``/*+label(...)*/`` hint.
    """

    def __init__(self):
        super(VerticaQueryBuilder, self).__init__(dialect=Dialects.VERTICA)
        # Label to emit as a hint comment; None means no hint.
        self._hint = None

    @builder
    def hint(self, label):
        """
        Attaches a label hint to the query.

        :param label: Text interpolated verbatim into ``/*+label(...)*/``.
        """
        self._hint = label

    def get_sql(self, *args, **kwargs):
        """
        Renders the query and, if a hint is set, injects it after the leading keyword.

        An empty rendering (nothing to build) is returned unchanged rather than
        being turned into a bare hint comment.
        """
        sql = super(VerticaQueryBuilder, self).get_sql(*args, **kwargs)

        if self._hint is not None and sql:
            # Assumes the SQL starts with a 6-character keyword plus a space
            # (e.g. "SELECT "): "SELECT a" -> "SELECT /*+label(x)*/ a".
            # sql[6:] re-supplies the space that follows the hint.
            sql = ''.join([sql[:7],
                           '/*+label({hint})*/'.format(hint=self._hint),
                           sql[6:]])

        return sql


class VerticaQuery(Query):
    """
    Query entry point for Vertica.

    ``_builder`` hands out a fresh VerticaQueryBuilder on every call.
    """

    @classmethod
    def _builder(cls):
        vertica_builder = VerticaQueryBuilder()
        return vertica_builder


class OracleQueryBuilder(QueryBuilder):
    """
    Query builder for Oracle.

    Always renders with ``groupby_alias=False`` (presumably because Oracle does
    not accept select-list aliases in GROUP BY — confirm against the base builder).
    """

    def __init__(self):
        super(OracleQueryBuilder, self).__init__(dialect=Dialects.ORACLE)

    def get_sql(self, *args, **kwargs):
        parent = super(OracleQueryBuilder, self)
        return parent.get_sql(*args, groupby_alias=False, **kwargs)


class OracleQuery(Query):
    """
    Query entry point for Oracle.

    ``_builder`` hands out a fresh OracleQueryBuilder on every call.
    """

    @classmethod
    def _builder(cls):
        oracle_builder = OracleQueryBuilder()
        return oracle_builder


class PostgreQueryBuilder(QueryBuilder):
    """
    Query builder for PostgreSQL.

    Adds an ``ON CONFLICT`` clause (``DO NOTHING`` or ``DO UPDATE SET``) for
    INSERT queries and a ``RETURNING`` clause for INSERT, UPDATE and DELETE
    queries.
    """

    def __init__(self):
        super(PostgreQueryBuilder, self).__init__(dialect=Dialects.POSTGRESQL)
        # Terms rendered by the RETURNING clause, in order.
        self._returns = []
        # Set once a star has been added to RETURNING.
        self._return_star = False
        # Conflict target Field, or None when no target was given.
        self._on_conflict_field = None
        # Conflict handlers; do_nothing() and do_update() are mutually exclusive.
        self._on_conflict_do_nothing = False
        self._on_conflict_updates = []

    def __copy__(self):
        newone = super(PostgreQueryBuilder, self).__copy__()
        # These lists are appended to in place, so a copy must not share them.
        newone._returns = copy(self._returns)
        newone._on_conflict_updates = copy(self._on_conflict_updates)
        return newone

    @builder
    def on_conflict(self, target_field):
        """
        Adds an ``ON CONFLICT`` clause to an INSERT query.

        :param target_field: The conflict target, as a column name (resolved
            against the insert table) or a ``Field``. Values of other types
            are ignored, leaving the clause without a target.
        :raises QueryException: if the query has no insert table.
        """
        if not self._insert_table:
            raise QueryException('On conflict only applies to insert query')
        if isinstance(target_field, str):
            self._on_conflict_field = self._conflict_field_str(target_field)
        elif isinstance(target_field, Field):
            self._on_conflict_field = target_field

    @builder
    def do_nothing(self):
        """
        Uses ``DO NOTHING`` as the conflict handler.

        :raises QueryException: if ``do_update`` was already used.
        """
        if len(self._on_conflict_updates) > 0:
            raise QueryException('Can not have two conflict handlers')
        self._on_conflict_do_nothing = True

    @builder
    def do_update(self, update_field, update_value):
        """
        Adds a ``field=value`` assignment to the ``DO UPDATE SET`` handler.

        :param update_field: A column name (resolved against the insert table)
            or a ``Field``.
        :param update_value: The value to assign; wrapped in ``ValueWrapper``.
        :raises QueryException: if ``do_nothing`` was already used, or if
            ``update_field`` is neither a ``str`` nor a ``Field``.
        """
        if self._on_conflict_do_nothing:
            raise QueryException('Can not have two conflict handlers')

        if isinstance(update_field, str):
            field = self._conflict_field_str(update_field)
        elif isinstance(update_field, Field):
            field = update_field
        else:
            # Without this branch `field` would be unbound below.
            raise QueryException('Unsupported on conflict update field type: {}'.format(
                type(update_field).__name__))
        self._on_conflict_updates.append((field, ValueWrapper(update_value)))

    def _conflict_field_str(self, term):
        # Returns None when there is no insert table to attach the column to.
        if self._insert_table:
            return Field(term, table=self._insert_table)

    def _on_conflict_sql(self, **kwargs):
        """
        Renders the ON CONFLICT clause, or '' when no conflict handling was requested.

        :raises QueryException: if a target was set without a handler, or if
            ``DO UPDATE`` is used without a conflict target.
        """
        if not self._on_conflict_do_nothing and len(self._on_conflict_updates) == 0:
            if self._on_conflict_field:
                raise QueryException('No handler defined for on conflict')
            return ''

        conflict_query = ' ON CONFLICT'
        if self._on_conflict_field:
            conflict_query += ' (' + self._on_conflict_field.get_sql(with_alias=True, **kwargs) + ')'

        if self._on_conflict_do_nothing:
            conflict_query += ' DO NOTHING'
        elif len(self._on_conflict_updates) > 0:
            if not self._on_conflict_field:
                raise QueryException('Can not have fieldless on conflict do update')
            updates = ','.join(
                '{field}={value}'.format(field=field.get_sql(**kwargs),
                                         value=value.get_sql(**kwargs))
                for field, value in self._on_conflict_updates
            )
            conflict_query += ' DO UPDATE SET {updates}'.format(updates=updates)

        return conflict_query

    @builder
    def returning(self, *terms):
        """
        Adds terms to the RETURNING clause.

        Accepts ``Field`` objects, column names (``'*'`` selects everything),
        arithmetic expressions and constants; functions are rejected.

        :raises QueryException: for functions, for queries that cannot have
            RETURNING, or for fields from unrelated tables.
        """
        for term in terms:
            if isinstance(term, Field):
                self._return_field(term)
            elif isinstance(term, str):
                self._return_field_str(term)
            elif isinstance(term, ArithmeticExpression):
                self._return_other(term)
            elif isinstance(term, Function):
                raise QueryException('Aggregate functions are not allowed in returning')
            else:
                self._return_other(self.wrap_constant(term, self._wrapper_cls))

    def _validate_returning_term(self, term):
        for field in term.fields():
            if not any([self._insert_table, self._update_table, self._delete_from]):
                raise QueryException('Returning can\'t be used in this query')
            # Each returned column must belong to the INSERT/UPDATE target or to
            # a table in the FROM clause; compare the field's table, not the term.
            if (
                  field.table not in {self._insert_table, self._update_table}
                  and field.table not in self._from
            ):
                raise QueryException('You can\'t return from other tables')

    def _set_returns_for_star(self):
        # Drop table-bound terms already collected: the star supersedes them.
        self._returns = [returning
                         for returning in self._returns
                         if not hasattr(returning, 'table')]
        self._return_star = True

    def _return_field(self, term):
        if self._return_star:
            # Do not add select terms after a star is selected
            return

        self._validate_returning_term(term)

        if isinstance(term, Star):
            self._set_returns_for_star()

        self._returns.append(term)

    def _return_field_str(self, term):
        if term == '*':
            self._set_returns_for_star()
            self._returns.append(Star())
            return

        # Resolve the column name against whichever table the query targets.
        if self._insert_table:
            self._return_field(Field(term, table=self._insert_table))
        elif self._update_table:
            self._return_field(Field(term, table=self._update_table))
        elif self._delete_from:
            self._return_field(Field(term, table=self._from[0]))
        else:
            raise QueryException('Returning can\'t be used in this query')

    def _return_other(self, function):
        self._validate_returning_term(function)
        self._returns.append(function)

    def _returning_sql(self, **kwargs):
        return ' RETURNING {returning}'.format(
              returning=','.join(term.get_sql(with_alias=True, **kwargs)
                                 for term in self._returns),
        )

    def get_sql(self, with_alias=False, subquery=False, **kwargs):
        """
        Renders the query followed by its ON CONFLICT and RETURNING clauses.

        The (defaulted) keyword arguments are forwarded to the appended clauses
        as well, so they are rendered with the same quoting settings as the
        rest of the query.
        """
        self._set_kwargs_defaults(kwargs)
        querystring = super(PostgreQueryBuilder, self).get_sql(with_alias, subquery, **kwargs)
        querystring += self._on_conflict_sql(**kwargs)
        if self._returns:
            querystring += self._returning_sql(**kwargs)
        return querystring


class PostgreSQLQuery(Query):
    """
    Query entry point for PostgreSQL.

    ``_builder`` hands out a fresh PostgreQueryBuilder on every call.
    """

    @classmethod
    def _builder(cls):
        postgres_builder = PostgreQueryBuilder()
        return postgres_builder


class RedshiftQuery(Query):
    """
    Query entry point for Amazon Redshift.

    Uses the generic QueryBuilder configured with the Redshift dialect.
    """

    @classmethod
    def _builder(cls):
        redshift_dialect = Dialects.REDSHIFT
        return QueryBuilder(dialect=redshift_dialect)


class MSSQLQueryBuilder(QueryBuilder):
    """
    Query builder for Microsoft SQL Server.

    Adds support for ``SELECT TOP (n)`` and renders with ``groupby_alias=False``.
    """

    def __init__(self):
        super(MSSQLQueryBuilder, self).__init__(dialect=Dialects.MSSQL)
        # Row limit for TOP; None means no TOP clause.
        self._top = None

    @builder
    def top(self, value):
        """
        Implements support for simple TOP clauses.

        Does not include support for PERCENT or WITH TIES.

        https://docs.microsoft.com/en-us/sql/t-sql/queries/top-transact-sql?view=sql-server-2017

        :param value: The row limit; anything accepted by ``int()``.
        :raises QueryException: if ``value`` cannot be converted to an int
            (including ``None`` and other non-numeric objects).
        """
        try:
            self._top = int(value)
        except (ValueError, TypeError) as err:
            raise QueryException('TOP value must be an integer') from err

    def get_sql(self, *args, **kwargs):
        return super(MSSQLQueryBuilder, self).get_sql(*args, groupby_alias=False, **kwargs)

    def _top_sql(self):
        # Compare with None so that TOP (0) is rendered instead of silently dropped.
        if self._top is not None:
            return 'TOP ({}) '.format(self._top)
        return ''

    def _select_sql(self, **kwargs):
        return 'SELECT {distinct}{top}{select}'.format(
              top=self._top_sql(),
              distinct='DISTINCT ' if self._distinct else '',
              select=','.join(term.get_sql(with_alias=True, subquery=True, **kwargs)
                              for term in self._selects),
        )


class MSSQLQuery(Query):
    """
    Query entry point for Microsoft SQL Server.

    ``_builder`` hands out a fresh MSSQLQueryBuilder on every call.
    """

    @classmethod
    def _builder(cls):
        mssql_builder = MSSQLQueryBuilder()
        return mssql_builder


class ClickHouseQuery(Query):
    """
    Query entry point for Yandex ClickHouse.

    Uses the generic QueryBuilder with the ClickHouse dialect and
    ``wrap_union_queries=False``.
    """

    @classmethod
    def _builder(cls):
        options = dict(dialect=Dialects.CLICKHOUSE, wrap_union_queries=False)
        return QueryBuilder(**options)


class SQLLiteValueWrapper(ValueWrapper):
    """ValueWrapper that renders Python booleans as the literals 1 and 0."""

    def get_value_sql(self, *args, **kwargs):
        wrapped = self.value
        if not isinstance(wrapped, bool):
            # Everything that is not a bool keeps the default rendering.
            return super().get_value_sql(*args, **kwargs)
        return '1' if wrapped else '0'


class SQLLiteQuery(Query):
    """
    Defines a query class for use with SQLite.
    """

    @classmethod
    def _builder(cls):
        # SQLLiteValueWrapper renders Python booleans as 1/0.
        return QueryBuilder(dialect=Dialects.SQLLITE, wrapper_cls=SQLLiteValueWrapper)