#!/usr/bin/env python

##
 # Copyright (c) 2013-2015, Roland Bock
 # All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without modification,
 # are permitted provided that the following conditions are met:
 #
 #  * Redistributions of source code must retain the above copyright notice,
 #    this list of conditions and the following disclaimer.
 #  * Redistributions in binary form must reproduce the above copyright notice,
 #    this list of conditions and the following disclaimer in the documentation
 #    and/or other materials provided with the distribution.
 #
 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 # IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
 # INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
 # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 # OF THE POSSIBILITY OF SUCH DAMAGE.
 ##

from __future__ import print_function
import sys
import re
import os

from pyparsing import CaselessLiteral, Literal, SkipTo, restOfLine, oneOf, ZeroOrMore, Optional, Combine, \
    WordStart, WordEnd, Word, alphas, alphanums, nums, QuotedString, nestedExpr, MatchFirst, OneOrMore, delimitedList, Or, Group

# Header directory of the library; the generated file includes
# <sqlpp11/table.h>, <sqlpp11/data_types.h> and <sqlpp11/char_sequence.h>.
INCLUDE = 'sqlpp11'
# C++ namespace used to qualify table_t, make_traits, the data types and tags
# referenced by the generated code.
NAMESPACE = 'sqlpp'

# HELPERS
def get_include_guard_name(namespace, inputfile):
  """Return the include-guard macro name for a generated header.

  The namespace and the base name of *inputfile* are joined with '_'. Every
  run of characters that are not ASCII letters or digits is replaced by a
  single '_', and the result is upper-cased, e.g.
  ('ns', '/tmp/Table1.h') -> 'NS_TABLE1_H'.
  """
  # Digits are kept deliberately: stripping them would give headers that
  # differ only in digits (Table1.h / Table2.h) the same guard, and the second
  # header included would then be silently skipped.
  val = re.sub("[^A-Za-z0-9]+", "_", namespace + '_' + os.path.basename(inputfile))
  return val.upper()

def repl_func(m):
  if (m.group(1) == '_'):
    return m.group(2).upper()
  else:
    return m.group(1) + m.group(2).upper()

def toClassName(s):
  """Convert an SQL identifier into a C++ class name.

  The first character, and every character that follows whitespace, '_' or a
  digit, is upper-cased. Underscores are removed; whitespace and digits are
  kept. Example: 'tab_foo' -> 'TabFoo'.
  """
  # Raw string: in a plain literal "\s" and "\S" are invalid escape sequences
  # (DeprecationWarning, and SyntaxWarning since Python 3.12).
  return re.sub(r"(^|\s|[_0-9])(\S)", repl_func, s)

def toMemberName(s):
  """Convert an SQL identifier into a C++ member name.

  Every character that follows whitespace, '_' or a digit is upper-cased. The
  first character is left unchanged. Underscores are removed; whitespace and
  digits are kept. Example: 'tab_foo' -> 'tabFoo'.
  """
  # Raw string: in a plain literal "\s" and "\S" are invalid escape sequences
  # (DeprecationWarning, and SyntaxWarning since Python 3.12).
  return re.sub(r"(\s|_|[0-9])(\S)", repl_func, s)


# PARSER
def ddlWord(string):
    """Build a parser element matching *string* case-insensitively as a whole word.

    The keyword must not be immediately preceded or followed by a letter, a
    digit or '_', so that e.g. KEY does not match inside "keyword" or "key_id".
    """
    wordChars = alphanums + "_"
    keyword = CaselessLiteral(string)
    return WordStart(wordChars) + keyword + WordEnd(wordChars)

# String literal: single quotes, double quotes ("" escapes a quote) or backticks.
ddlString   = Or([QuotedString("'"), QuotedString("\"", escQuote='""'), QuotedString("`")])
negativeSign = Literal('-')
# Number: an optional leading '-' followed by digits and dots (e.g. 3.14);
# the dots are not validated further.
ddlNum     = Combine(Optional(negativeSign) + Word(nums + "."))
# Identifier or type name: a letter, then letters, digits, '_' or '$'.
ddlTerm     = Word(alphas, alphanums + "_$")
# Parenthesized, comma-separated arguments, such as the "(255)" of varchar(255).
ddlArguments = "(" + delimitedList(Or([ddlString, ddlTerm, ddlNum])) + ")"
# Column attributes. Their results names (notNull, hasDefaultValue,
# hasAutoValue) are read back when the column traits are generated.
ddlNotNull = Group(ddlWord("NOT") + ddlWord("NULL")).setResultsName("notNull")
ddlDefaultValue = ddlWord("DEFAULT").setResultsName("hasDefaultValue");
ddlAutoValue = ddlWord("AUTO_INCREMENT").setResultsName("hasAutoValue");
# NOTE(review): in ddlColumn's MatchFirst, ddlTerm comes before this element and
# also matches the word COMMENT, so this alternative appears never to be used.
# The "comment" result is not read anywhere in this file either.
ddlColumnComment  = Group(ddlWord("COMMENT") + ddlString).setResultsName("comment")
# Keywords that start a table-level constraint/index entry rather than a column.
ddlConstraint = Or([
        ddlWord("CONSTRAINT"),
        ddlWord("PRIMARY"),
        ddlWord("FOREIGN"),
        ddlWord("KEY"),
        ddlWord("INDEX"),
        ddlWord("UNIQUE"),
        ])
# One comma-separated entry of the column list. Entries starting with a
# constraint keyword get "isConstraint" set and are skipped by the generator.
# For all other entries the generator reads token 0 as the column name and
# token 1 as the SQL type. MatchFirst tries the alternatives in the listed
# order, so the keyword attributes must stay ahead of the generic ddlTerm.
ddlColumn   = Group(Optional(ddlConstraint).setResultsName("isConstraint") + OneOrMore(MatchFirst([ddlNotNull, ddlAutoValue, ddlDefaultValue, ddlTerm, ddlNum, ddlColumnComment, ddlString, ddlArguments])))
# CREATE TABLE <name> ( <entries> ). The table name and the entry list are
# captured as "tableName" and "columns" under the "create" group.
createTable = Group(ddlWord("CREATE") + ddlWord("TABLE") + ddlTerm.setResultsName("tableName") + "(" + Group(delimitedList(ddlColumn)).setResultsName("columns") + ")").setResultsName("create")


# Skip any text up to and including each CREATE TABLE statement (include=True),
# producing one result per table found.
ddl = ZeroOrMore(SkipTo(createTable, True))

# Line comments starting with "--" or "#" are ignored while parsing.
ddlComment = oneOf(["--", "#"]) + restOfLine
ddl.ignore(ddlComment)

# MAP SQL TYPES
# Maps a lower-cased SQL column type (the token after the column name) to the
# sqlpp11 data type used in the column's traits. A DDL column whose type is
# not listed here cannot be converted.
types = {
    'tinyint': 'tinyint',
    'smallint': 'smallint',
    'integer': 'integer',
    'int': 'integer',
    'mediumint': 'integer',
    'bigint': 'bigint',
    'char': 'char_',
    'varchar': 'varchar',
    'text': 'text',
    'tinytext': 'text',
    'mediumtext': 'text',
    'longtext': 'text',
    'tinyblob': 'blob',
    'blob': 'blob',
    'mediumblob': 'blob',
    'longblob': 'blob',
    'bool': 'boolean',
    'boolean': 'boolean',
    'double': 'floating_point',
    'float': 'floating_point',
    'real': 'floating_point',
    'date' : 'day_point',
    'datetime' : 'time_point',
    'timestamp' : 'time_point',
    }

# PROCESS DDL
# Command line: ddl2cpp <ddl file> <target path without extension> <namespace>
if len(sys.argv) != 4:
  print('Usage: ddl2cpp <path to ddl> <path to target (without extension, e.g. /tmp/MyTable)> <namespace>')
  sys.exit(1)

pathToDdl = sys.argv[1]
pathToHeader = sys.argv[2] + '.h'
namespace = sys.argv[3]

# Parse and validate before the header is opened. A missing or unsupported
# DDL file then neither truncates an existing header nor leaves a half-written
# one behind. parseFile reads pathToDdl itself, so no separate handle is opened.
tableCreations = ddl.parseFile(pathToDdl)

for tableCreation in tableCreations:
    for column in tableCreation.create.columns:
        if column.isConstraint:
            continue
        if column[1].lower() not in types:
            sys.exit('ddl2cpp: unsupported SQL type "%s" for column "%s" in table "%s"'
                     % (column[1], column[0], tableCreation.create.tableName))

includeGuard = get_include_guard_name(namespace, pathToHeader)

# The context manager guarantees the header is flushed and closed.
with open(pathToHeader, 'w') as header:
    print('// generated by ' + ' '.join(sys.argv), file=header)
    print('#ifndef ' + includeGuard, file=header)
    print('#define ' + includeGuard, file=header)
    print('', file=header)
    print('#include <' + INCLUDE + '/table.h>', file=header)
    print('#include <' + INCLUDE + '/data_types.h>', file=header)
    print('#include <' + INCLUDE + '/char_sequence.h>', file=header)
    print('', file=header)
    print('namespace ' + namespace, file=header)
    print('{', file=header)

    for tableCreation in tableCreations:
        sqlTableName = tableCreation.create.tableName
        tableClass = toClassName(sqlTableName)
        tableMember = toMemberName(sqlTableName)
        tableNamespace = tableClass + '_'
        # Template arguments of table_t: the table class followed by every column class.
        tableTemplateParameters = tableClass
        print('  namespace ' + tableNamespace, file=header)
        print('  {', file=header)
        for column in tableCreation.create.columns:
            if column.isConstraint:
                continue
            sqlColumnName = column[0]
            columnClass = toClassName(sqlColumnName)
            tableTemplateParameters += ',\n               ' + tableNamespace + '::' + columnClass
            columnMember = toMemberName(sqlColumnName)
            sqlColumnType = column[1].lower()
            print('    struct ' + columnClass, file=header)
            print('    {', file=header)
            print('      struct _alias_t', file=header)
            print('      {', file=header)
            print('        static constexpr const char _literal[] =  "' + sqlColumnName + '";', file=header)
            print('        using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;', file=header)
            print('        template<typename T>', file=header)
            print('        struct _member_t', file=header)
            print('          {', file=header)
            print('            T ' + columnMember + ';', file=header)
            print('            T& operator()() { return ' + columnMember + '; }', file=header)
            print('            const T& operator()() const { return ' + columnMember + '; }', file=header)
            print('          };', file=header)
            print('      };', file=header)
            # Traits list: the value type first, then tags derived from the
            # column attributes. A column is required on insert unless it is
            # auto-increment, nullable, or has a DEFAULT.
            traitslist = [NAMESPACE + '::' + types[sqlColumnType]]
            requireInsert = True
            if column.hasAutoValue:
                traitslist.append(NAMESPACE + '::tag::must_not_insert')
                traitslist.append(NAMESPACE + '::tag::must_not_update')
                requireInsert = False
            if not column.notNull:
                traitslist.append(NAMESPACE + '::tag::can_be_null')
                requireInsert = False
            if column.hasDefaultValue:
                requireInsert = False
            if requireInsert:
                traitslist.append(NAMESPACE + '::tag::require_insert')
            print('      using _traits = ' + NAMESPACE + '::make_traits<' + ', '.join(traitslist) + '>;', file=header)
            print('    };', file=header)
        print('  }', file=header)
        print('', file=header)

        print('  struct ' + tableClass + ': ' + NAMESPACE + '::table_t<' + tableTemplateParameters + '>', file=header)
        print('  {', file=header)
        print('    struct _alias_t', file=header)
        print('    {', file=header)
        print('      static constexpr const char _literal[] =  "' + sqlTableName + '";', file=header)
        print('      using _name_t = sqlpp::make_char_sequence<sizeof(_literal), _literal>;', file=header)
        print('      template<typename T>', file=header)
        print('      struct _member_t', file=header)
        print('      {', file=header)
        print('        T ' + tableMember + ';', file=header)
        print('        T& operator()() { return ' + tableMember + '; }', file=header)
        print('        const T& operator()() const { return ' + tableMember + '; }', file=header)
        print('      };', file=header)
        print('    };', file=header)
        print('  };', file=header)

    print('}', file=header)
    print('#endif', file=header)

