diff --git a/src/lib/dbml-import.ts b/src/lib/dbml-import.ts index 2f59d8af..0959094f 100644 --- a/src/lib/dbml-import.ts +++ b/src/lib/dbml-import.ts @@ -28,10 +28,24 @@ interface DBMLField { increment?: boolean; } +interface DBMLIndexColumn { + value: string; + type?: string; + length?: number; + order?: 'asc' | 'desc'; +} + +interface DBMLIndex { + columns: string | (string | DBMLIndexColumn)[]; + unique?: boolean; + name?: string; +} + interface DBMLTable { name: string; schema?: string | { name: string }; fields: DBMLField[]; + indexes?: DBMLIndex[]; } interface DBMLEndpoint { @@ -99,18 +113,60 @@ export const importDBMLToDiagram = async ( // Extract only the necessary data from the parsed DBML const extractedData = { - tables: dbmlData.tables.map((table: DBMLTable) => ({ - name: table.name, - schema: table.schema, - fields: table.fields.map((field: DBMLField) => ({ - name: field.name, - type: field.type, - unique: field.unique, - pk: field.pk, - not_null: field.not_null, - increment: field.increment, - })), - })), + tables: (dbmlData.tables as unknown as DBMLTable[]).map( + (table) => ({ + name: table.name, + schema: table.schema, + fields: table.fields.map((field: DBMLField) => ({ + name: field.name, + type: field.type, + unique: field.unique, + pk: field.pk, + not_null: field.not_null, + increment: field.increment, + })), + indexes: + table.indexes?.map((dbmlIndex) => { + let indexColumns: string[]; + + // Handle composite index case "(col1, col2)" + if (typeof dbmlIndex.columns === 'string') { + if (dbmlIndex.columns.includes('(')) { + // Composite index + const columnsStr = + dbmlIndex.columns.replace(/[()]/g, ''); + indexColumns = columnsStr + .split(',') + .map((c) => c.trim()); + } else { + // Single column + indexColumns = [dbmlIndex.columns.trim()]; + } + } else { + // Handle array of columns + indexColumns = Array.isArray(dbmlIndex.columns) + ? dbmlIndex.columns.map((col) => + typeof col === 'object' && + 'value' in col + ? (col.value as string).trim() + : (col as string).trim() + ) + : [String(dbmlIndex.columns).trim()]; + } + + // Generate a consistent index name + const indexName = + dbmlIndex.name || + `idx_${table.name}_${indexColumns.join('_')}`; + + return { + columns: indexColumns, + unique: dbmlIndex.unique || false, + name: indexName, + }; + }) || [], + }) + ), refs: (dbmlData.refs as unknown as DBMLRef[]).map((ref) => ({ endpoints: (ref.endpoints as [DBMLEndpoint, DBMLEndpoint]).map( (endpoint) => ({ @@ -126,7 +182,42 @@ export const importDBMLToDiagram = async ( const tables: DBTable[] = extractedData.tables.map((table, index) => { const row = Math.floor(index / 4); const col = index % 4; - const tableSpacing = 300; // Increased spacing between tables + const tableSpacing = 300; + + // Create fields first so we have their IDs + const fields = table.fields.map((field) => ({ + id: generateId(), + name: field.name.replace(/['"]/g, ''), + type: mapDBMLTypeToGenericType(field.type.type_name), + nullable: !field.not_null, + primaryKey: field.pk || false, + unique: field.unique || false, + createdAt: Date.now(), + })); + + // Convert DBML indexes to ChartDB indexes + const indexes = + table.indexes?.map((dbmlIndex) => { + const fieldIds = dbmlIndex.columns.map((columnName) => { + const field = fields.find((f) => f.name === columnName); + if (!field) { + throw new Error( + `Index references non-existent column: ${columnName}` + ); + } + return field.id; + }); + + return { + id: generateId(), + name: + dbmlIndex.name || + `idx_${table.name}_${dbmlIndex.columns.join('_')}`, + fieldIds, + unique: dbmlIndex.unique || false, + createdAt: Date.now(), + }; + }) || []; return { id: generateId(), @@ -136,18 +227,10 @@ export const importDBMLToDiagram = async ( ? table.schema : table.schema?.name || '', order: index, - fields: table.fields.map((field) => ({ - id: generateId(), - name: field.name.replace(/['"]/g, ''), - type: mapDBMLTypeToGenericType(field.type.type_name), - nullable: !field.not_null, - primaryKey: field.pk || false, - unique: field.unique || false, - createdAt: Date.now(), - })), + fields, + indexes, x: col * tableSpacing, y: row * tableSpacing, - indexes: [], color: randomColor(), isView: false, createdAt: Date.now(),