Source code for fireant.database.snowflake

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from pypika import (
    functions as fn,
    terms,
)
from pypika.dialects import SnowflakeQuery
from .base import Database

try:
    from snowflake import connector as snowflake
except:
    pass

IGNORED_SCHEMAS = {'INFORMATION_SCHEMA'}


[docs]class Trunc(terms.Function): """ Wrapper for TRUNC function for truncating dates. """ def __init__(self, field, date_format, alias=None): super(Trunc, self).__init__('TRUNC', field, date_format, alias=alias) # Setting the fields here means we can access the TRUNC args by name. self.field = field self.date_format = date_format self.alias = alias
[docs]class SnowflakeDatabase(Database): """ Snowflake client. """ # The pypika query class to use for constructing queries query_cls = SnowflakeQuery DATETIME_INTERVALS = { 'hour': 'HH', 'day': 'DD', 'week': 'IW', 'month': 'MM', 'quarter': 'Q', 'year': 'Y' } _private_key = None def __init__(self, user='snowflake', password=None, account='snowflake', database='snowflake', private_key_data=None, private_key_password=None, region=None, warehouse=None, **kwags): super(SnowflakeDatabase, self).__init__(database=database, **kwags) self.user = user self.password = password self.account = account self.private_key_data = private_key_data self.private_key_password = private_key_password self.region = region self.warehouse = warehouse
[docs] def connect(self): import snowflake return snowflake.connector.connect(database=self.database, account=self.account, user=self.user, password=self.password, private_key=self._get_private_key(), region=self.region, warehouse=self.warehouse)
[docs] def trunc_date(self, field, interval): trunc_date_interval = self.DATETIME_INTERVALS.get(str(interval), 'DD') return Trunc(field, trunc_date_interval)
[docs] def date_add(self, field, date_part, interval): return fn.TimestampAdd(str(date_part), interval, field)
def _get_private_key(self): if self._private_key is None: self._private_key = self._load_private_key_data() return self._private_key def _load_private_key_data(self): if self.private_key_data is None: return None private_key_password = None \ if self.private_key_password is None \ else self.private_key_password.encode() pkey = serialization.load_pem_private_key(self.private_key_data.encode(), private_key_password, backend=default_backend()) return pkey.private_bytes(encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption())
[docs] def get_column_definitions(self, schema, table): """ Return a list of column name, column data type pairs """ columns = self.fetch('DESCRIBE TABLE {}.{} TYPE=COLUMNS'.format(schema, table)) return [column[0:2] for column in columns]