diff --git a/frontend/src/pages/DatabaseSchema.tsx b/frontend/src/pages/DatabaseSchema.tsx index e795fc1..2d12c2b 100644 --- a/frontend/src/pages/DatabaseSchema.tsx +++ b/frontend/src/pages/DatabaseSchema.tsx @@ -19,22 +19,21 @@ import { Database as DatabaseIcon, Loader2, Key, Link, RefreshCw } from 'lucide- import { databasesApi, schemaApi, TableInfo, SchemaData } from '@/services/api'; import { Database } from '@/types'; -// Custom node for table +// Calculate column position for handle placement +function getColumnYPosition(columnIndex: number): number { + const headerHeight = 44; + const columnHeight = 26; + return headerHeight + (columnIndex * columnHeight) + columnHeight / 2; +} + +// Custom node for table with per-column handles function TableNode({ data }: { data: TableInfo & Record }) { + // Find which columns are FK sources and which are PK targets + const fkColumns = new Set(data.foreign_keys.map(fk => fk.column)); + const pkColumns = new Set(data.columns.filter(c => c.is_primary).map(c => c.name)); + return (
- {/* Connection handles */} - - - {/* Table header */}
@@ -50,15 +49,37 @@ function TableNode({ data }: { data: TableInfo & Record }) { {/* Columns */}
- {data.columns.map((col) => ( + {data.columns.map((col, index) => (
+ {/* Target handle for PK columns (incoming FK references) */} + {pkColumns.has(col.name) && ( + + )} + + {/* Source handle for FK columns (outgoing references) */} + {fkColumns.has(col.name) && ( + + )} +
{col.is_primary && } - {data.foreign_keys.some(fk => fk.column === col.name) && ( + {fkColumns.has(col.name) && ( )}
@@ -96,33 +117,68 @@ function getNodeHeight(table: TableInfo): number { return headerHeight + (table.columns.length * columnHeight) + fkBarHeight + 8; } +// Calculate connection weight for layout optimization +function calculateTableWeights(schema: SchemaData): Map { + const weights = new Map(); + + schema.tables.forEach(table => { + const key = `${table.schema}.${table.name}`; + // Weight = number of FK connections (both incoming and outgoing) + let weight = table.foreign_keys.length; + + // Add incoming connections + schema.tables.forEach(other => { + other.foreign_keys.forEach(fk => { + if (fk.references_table === table.name) { + weight++; + } + }); + }); + + weights.set(key, weight); + }); + + return weights; +} + function getLayoutedElements(schema: SchemaData): { nodes: Node[]; edges: Edge[] } { const g = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({})); + // Calculate weights for better positioning + const weights = calculateTableWeights(schema); + + // Sort tables by weight (most connected first) for better layout + const sortedTables = [...schema.tables].sort((a, b) => { + const weightA = weights.get(`${a.schema}.${a.name}`) || 0; + const weightB = weights.get(`${b.schema}.${b.name}`) || 0; + return weightB - weightA; + }); + g.setGraph({ rankdir: 'LR', - nodesep: 80, - ranksep: 120, + nodesep: 60, + ranksep: 150, marginx: 50, marginy: 50, + ranker: 'tight-tree', // Better for connected graphs }); // Add nodes - schema.tables.forEach((table) => { + sortedTables.forEach((table) => { const nodeId = `${table.schema}.${table.name}`; const width = 280; const height = getNodeHeight(table); g.setNode(nodeId, { width, height }); }); - // Add edges + // Add edges with weights for layout algorithm schema.tables.forEach((table) => { const sourceId = `${table.schema}.${table.name}`; table.foreign_keys.forEach((fk) => { const targetTable = schema.tables.find(t => t.name === fk.references_table); if (targetTable) { const targetId = `${targetTable.schema}.${targetTable.name}`; - g.setEdge(sourceId, targetId); + g.setEdge(sourceId, targetId, { weight: 2 }); // Higher weight keeps connected nodes closer } }); }); @@ -143,16 +199,15 @@ function getLayoutedElements(schema: SchemaData): { nodes: Node[]; edges: Edge[] y: nodeWithPosition.y - nodeWithPosition.height / 2, }, data: { ...table } as TableInfo & Record, - sourcePosition: Position.Right, - targetPosition: Position.Left, }; }); - // Create edges + // Create edges with column-level handles const edges: Edge[] = []; schema.tables.forEach((table) => { const sourceId = `${table.schema}.${table.name}`; + table.foreign_keys.forEach((fk) => { let targetTable = schema.tables.find(t => t.name === fk.references_table && t.schema === table.schema @@ -163,21 +218,21 @@ function getLayoutedElements(schema: SchemaData): { nodes: Node[]; edges: Edge[] if (targetTable) { const targetId = `${targetTable.schema}.${targetTable.name}`; + edges.push({ id: `${fk.constraint_name}`, source: sourceId, target: targetId, + sourceHandle: `source-${fk.column}`, + targetHandle: `target-${fk.references_column}`, type: 'smoothstep', - animated: true, - style: { stroke: '#3b82f6', strokeWidth: 2 }, + style: { stroke: '#3b82f6', strokeWidth: 1.5 }, markerEnd: { type: MarkerType.ArrowClosed, color: '#3b82f6', + width: 15, + height: 15, }, - label: `${fk.column} → ${fk.references_column}`, - labelStyle: { fontSize: 10, fill: '#666' }, - labelBgStyle: { fill: 'white', fillOpacity: 0.9 }, - labelBgPadding: [4, 2] as [number, number], }); } });