#!/usr/bin/env python

##
 # Copyright (c) 2013-2014, Roland Bock
 # All rights reserved.
 # 
 # Redistribution and use in source and binary forms, with or without modification, 
 # are permitted provided that the following conditions are met:
 # 
 #  * Redistributions of source code must retain the above copyright notice, 
 #    this list of conditions and the following disclaimer.
 #  * Redistributions in binary form must reproduce the above copyright notice, 
 #    this list of conditions and the following disclaimer in the documentation 
 #    and/or other materials provided with the distribution.
 # 
 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 
 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
 # IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 
 # INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 
 # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 
 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 
 # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 
 # OF THE POSSIBILITY OF SUCH DAMAGE.
 ##

from __future__ import print_function
import sys
import re
import os

from pyparsing import CaselessLiteral, SkipTo, restOfLine, oneOf, ZeroOrMore, Optional, \
    WordStart, WordEnd, Word, alphas, alphanums, nums, QuotedString, nestedExpr, MatchFirst, OneOrMore, delimitedList, Or, Group

INCLUDE = 'sqlpp11'
NAMESPACE = 'sqlpp'

# HELPERS
def get_include_guard_name(namespace, inputfile):
  """Return the include-guard macro name for a generated header.

  The namespace and the base name of ``inputfile`` are joined with an
  underscore, every run of characters other than ASCII letters and digits is
  collapsed into a single underscore, and the result is upper-cased.

  Digits are kept on purpose: dropping them would give files such as
  ``Table1.h`` and ``Table2.h`` the same guard, so only the first of the two
  headers would take effect when both are included.

  namespace -- C++ namespace the generated code is placed in
  inputfile -- path of the file the guard is derived from (only the base
               name is used)
  """
  val = re.sub(r"[^A-Za-z0-9]+", "_", namespace + '_' + os.path.basename(inputfile))
  return val.upper()

def repl_func(m):
  """Replacement callback for re.sub used by the name converters.

  m -- match object with two groups: a separator (group 1) and the
       character following it (group 2)

  Returns the upper-cased character, preceded by the separator unless the
  separator is an underscore (underscores are dropped).
  """
  separator, char = m.group(1), m.group(2)
  kept = '' if separator == '_' else separator
  return kept + char.upper()

def toClassName(s):
  """Convert an SQL identifier into a C++ class-style name.

  The first non-whitespace character at the start of the string, and every
  non-whitespace character that follows whitespace, an underscore or a digit,
  is upper-cased by repl_func; separating underscores are removed.

  Example: "tab_person" becomes "TabPerson".
  """
  # Raw string: the pattern contains backslash classes that are not valid
  # escapes in an ordinary string literal and trigger a warning there.
  return re.sub(r"(^|\s|[_0-9])(\S)", repl_func, s)

def toMemberName(s):
  """Convert an SQL identifier into a C++ member-style name.

  Every non-whitespace character that follows whitespace, an underscore or a
  digit is upper-cased by repl_func; separating underscores are removed. The
  first character of the string is left untouched.

  Example: "tab_person" becomes "tabPerson".
  """
  # Raw string: the pattern contains backslash classes that are not valid
  # escapes in an ordinary string literal and trigger a warning there.
  return re.sub(r"(\s|_|[0-9])(\S)", repl_func, s)


# PARSER
def ddlWord(string):
    """Build a parser for `string` as a case-insensitive, stand-alone word.

    The literal is wrapped in word-boundary checks over alphanumerics and the
    underscore, so it is not matched inside a longer identifier.
    """
    wordChars = alphanums + "_"
    keyword = CaselessLiteral(string)
    return WordStart(wordChars) + keyword + WordEnd(wordChars)

# Quoted string in one of three styles: single quotes, double quotes (with a
# doubled quote as the escaped quote), or backticks.
ddlString   = Or([QuotedString("'"), QuotedString("\"", escQuote='""'), QuotedString("`")])
# Number: one or more characters taken from the digits and '.'.
ddlNum     = Word(nums + ".")
# Bare word: starts with a letter, continues with alphanumerics, '_' or '$'.
ddlTerm     = Word(alphas, alphanums + "_$")
# Parenthesised, comma-separated list of strings, terms or numbers.
ddlArguments = "(" + delimitedList(Or([ddlString, ddlTerm, ddlNum])) + ")"
# Column attributes that the generator queries by results name.
ddlNotNull = Group(ddlWord("NOT") + ddlWord("NULL")).setResultsName("notNull")
ddlDefaultValue = ddlWord("DEFAULT").setResultsName("hasDefaultValue");
ddlAutoValue = ddlWord("AUTO_INCREMENT").setResultsName("hasAutoValue");
ddlColumnComment  = Group(ddlWord("COMMENT") + ddlString).setResultsName("comment")
# Keywords that may open an entry of the column list; ddlColumn records a
# match of one of them under the results name "isConstraint".
ddlConstraint = Or([
        ddlWord("CONSTRAINT"), 
        ddlWord("PRIMARY"), 
        ddlWord("FOREIGN"), 
        ddlWord("KEY"),
        ddlWord("INDEX"), 
        ddlWord("UNIQUE"), 
        ])
# One entry of the column list: an optional constraint keyword followed by one
# or more tokens. MatchFirst tries its alternatives in the order listed, so the
# named attribute expressions are tried before the generic ddlTerm.
ddlColumn   = Group(Optional(ddlConstraint).setResultsName("isConstraint") + OneOrMore(MatchFirst([ddlNotNull, ddlAutoValue, ddlDefaultValue, ddlTerm, ddlNum, ddlColumnComment, ddlString, ddlArguments])))
# CREATE TABLE <tableName> ( <columns> ), grouped under the results name
# "create" with the sub-results "tableName" and "columns".
createTable = Group(ddlWord("CREATE") + ddlWord("TABLE") + ddlTerm.setResultsName("tableName") + "(" + Group(delimitedList(ddlColumn)).setResultsName("columns") + ")").setResultsName("create")


# Whole input: repeatedly skip ahead to the next createTable match. The second
# positional argument of SkipTo is its `include` flag.
# NOTE(review): the exact shape of the returned results depends on the
# pyparsing version -- confirm against the version in use.
ddl = ZeroOrMore(SkipTo(createTable, True))

# Line comments introduced by "--" or "#" run to the end of the line and are
# registered as ignorable text for the grammar.
ddlComment = oneOf(["--", "#"]) + restOfLine
ddl.ignore(ddlComment)

# MAP SQL TYPES
# Lower-cased SQL column type -> name of the column type emitted for it in the
# generated header. Declared as (emitted type, SQL spellings) groups so that
# aliases of the same target type sit next to each other.
types = {
    sqlType: cppType
    for cppType, sqlTypes in (
        ('tinyint', ('tinyint',)),
        ('smallint', ('smallint',)),
        ('integer', ('integer', 'int')),
        ('bigint', ('bigint',)),
        ('char_', ('char',)),
        ('varchar', ('varchar',)),
        ('text', ('text',)),
        ('blob', ('tinyblob', 'blob', 'mediumblob', 'longblob')),
        ('boolean', ('bool',)),
        ('floating_point', ('double', 'float')),
    )
    for sqlType in sqlTypes
}

# PROCESS DDL
# Command line: <path to ddl> <path to target without extension> <namespace>.
if len(sys.argv) != 4:
    print('Usage: ddl2cpp <path to ddl> <path to target (without extension, e.g. /tmp/MyTable)> <namespace>')
    sys.exit(1)

pathToDdl = sys.argv[1]
pathToHeader = sys.argv[2] + '.h'
namespace = sys.argv[3]

# Parse before touching the target file, so that an error raised while reading
# or parsing the DDL does not truncate an existing header. parseFile opens the
# DDL file itself; no separate file handle is needed here.
tableCreations = ddl.parseFile(pathToDdl)

# The header text is collected in memory and written in one go at the end, so
# aborting on an unsupported column type leaves no partial header behind.
headerLines = []
emit = headerLines.append

includeGuard = get_include_guard_name(namespace, pathToHeader)
emit('#ifndef ' + includeGuard)
emit('#define ' + includeGuard)
emit('')
emit('#include <' + INCLUDE + '/table.h>')
emit('#include <' + INCLUDE + '/column_types.h>')
emit('')
emit('namespace ' + namespace)
emit('{')

for tableCreation in tableCreations:
    # Each result exposes the CREATE TABLE match under the results name "create".
    sqlTableName = tableCreation.create.tableName
    tableClass = toClassName(sqlTableName)
    tableMember = toMemberName(sqlTableName)
    tableNamespace = tableClass + '_'
    # Template arguments of table_t: the table class, then one entry per column.
    tableTemplateParameters = [tableClass]
    emit('  namespace ' + tableNamespace)
    emit('  {')
    for column in tableCreation.create.columns:
        # Entries opened by a constraint keyword do not describe a column.
        if column.isConstraint:
            continue
        # The first two tokens of a column entry are its name and its SQL type.
        sqlColumnName = column[0]
        sqlColumnType = column[1].lower()
        try:
            cppColumnType = types[sqlColumnType]
        except KeyError:
            print('Error: unsupported SQL type "' + sqlColumnType
                  + '" for column "' + sqlColumnName
                  + '" of table "' + sqlTableName + '"', file=sys.stderr)
            sys.exit(1)
        columnClass = toClassName(sqlColumnName)
        columnMember = toMemberName(sqlColumnName)
        tableTemplateParameters.append(tableNamespace + '::' + columnClass)

        emit('    struct ' + columnClass)
        emit('    {')
        emit('      struct _name_t')
        emit('      {')
        emit('        static constexpr const char* _get_name() { return "' + sqlColumnName + '"; }')
        emit('        template<typename T>')
        emit('        struct _member_t')
        emit('          {')
        emit('            T ' + columnMember + ';')
        emit('            T& operator()() { return ' + columnMember + '; }')
        emit('            const T& operator()() const { return ' + columnMember + '; }')
        emit('          };')
        emit('      };')

        # A value is required on insert only if the column has no auto value,
        # is declared NOT NULL and has no DEFAULT.
        traitslist = [NAMESPACE + '::' + cppColumnType]
        requireInsert = True
        if column.hasAutoValue:
            traitslist.append(NAMESPACE + '::tag::must_not_insert')
            traitslist.append(NAMESPACE + '::tag::must_not_update')
            requireInsert = False
        if not column.notNull:
            traitslist.append(NAMESPACE + '::tag::can_be_null')
            requireInsert = False
        if column.hasDefaultValue:
            requireInsert = False
        if requireInsert:
            traitslist.append(NAMESPACE + '::tag::require_insert')
        emit('      using _traits = ' + NAMESPACE + '::make_traits<' + ', '.join(traitslist) + '>;')
        emit('    };')
    emit('  }')
    emit('')

    emit('  struct ' + tableClass + ': ' + NAMESPACE + '::table_t<'
         + ',\n               '.join(tableTemplateParameters) + '>')
    emit('  {')
    emit('    struct _name_t')
    emit('    {')
    emit('      static constexpr const char* _get_name() { return "' + sqlTableName + '"; }')
    emit('      template<typename T>')
    emit('      struct _member_t')
    emit('      {')
    emit('        T ' + tableMember + ';')
    emit('        T& operator()() { return ' + tableMember + '; }')
    emit('        const T& operator()() const { return ' + tableMember + '; }')
    emit('      };')
    emit('    };')
    emit('  };')

emit('}')
emit('#endif')

# The context manager guarantees the header is flushed and closed.
with open(pathToHeader, 'w') as header:
    header.write('\n'.join(headerLines) + '\n')

