Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ func processMessage[Items model.Items](
if err != nil {
return nil, fmt.Errorf("error parsing logical message: %w", err)
}
customTypeMapping, err := p.fetchCustomTypeMapping(ctx)
customTypeMapping, err := p.fetchCustomTypeMapping(ctx, p.internalVersion)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1005,7 +1005,7 @@ func processRelationMessage[Items model.Items](
slog.Uint64("relId", uint64(currRel.RelationID)))
return nil, nil
}
customTypeMapping, err := p.fetchCustomTypeMapping(ctx)
customTypeMapping, err := p.fetchCustomTypeMapping(ctx, p.internalVersion)
if err != nil {
return nil, err
}
Expand Down
142 changes: 131 additions & 11 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/jackc/pgerrcode"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
Expand Down Expand Up @@ -140,13 +141,29 @@ func ParseConfig(connectionString string, pgConfig *protos.PostgresConfig) (*pgx
return connConfig, nil
}

func (c *PostgresConnector) fetchCustomTypeMapping(ctx context.Context) (map[uint32]shared.CustomDataType, error) {
func (c *PostgresConnector) fetchCustomTypeMapping(ctx context.Context, version uint32) (map[uint32]shared.CustomDataType, error) {
if c.customTypeMapping == nil {
customTypeMapping, err := shared.GetCustomDataTypes(ctx, c.conn)
if err != nil {
return nil, err
}
c.customTypeMapping = customTypeMapping

if version >= shared.InternalVersion_CompositeTypeAsTuple {
var compositeTypeNames []string
for _, typeData := range customTypeMapping {
if typeData.Type == 'c' && typeData.Delim == 0 { // Only composite types
compositeTypeNames = append(compositeTypeNames, typeData.Name)
}
}
types, err := c.conn.LoadTypes(ctx, compositeTypeNames)
if err != nil {
c.logger.Error("failed to load composite types",
slog.Any("error", err), slog.Any("composite_type_names", compositeTypeNames))
return nil, fmt.Errorf("failed to load composite types: %w", err)
}
c.typeMap.RegisterTypes(types)
}
}
return c.customTypeMapping, nil
}
Expand Down Expand Up @@ -923,6 +940,69 @@ func (c *PostgresConnector) GetSelectedColumns(
return columns, nil
}

func (c *PostgresConnector) getCompositeTypeDetails(ctx context.Context, system protos.TypeSystem, version uint32,
customTypeMapping map[uint32]shared.CustomDataType, oid uint32,
) ([]*protos.FieldDescription, error) {
result := make([]*protos.FieldDescription, 0)
subfields, err := shared.GetCompositeDataTypeDetails(ctx, c.conn, oid)
if err != nil {
return nil, fmt.Errorf("error getting composite data type details for %d: %w", oid, err)
}
for _, subfield := range subfields {
var subColType string
var err error
switch system {
case protos.TypeSystem_PG:
subColType, err = c.postgresOIDToName(subfield.Type.OID, customTypeMapping)
case protos.TypeSystem_Q:
qColType := c.postgresOIDToQValueKind(subfield.Type.OID, customTypeMapping, version)
subColType = string(qColType)
}

if err != nil {
return nil, fmt.Errorf("error getting type name for subfield %d: %w", subfield.Type.OID, err)
}
subCompositeFields := make([]*protos.FieldDescription, 0)
if subColType == "composite" {
subCompositeFields, err = c.getCompositeTypeDetails(ctx, system, version, customTypeMapping, subfield.Type.OID)
if err != nil {
return nil, fmt.Errorf("error getting composite type details for subfield %d: %w", subfield.Type.OID, err)
}
} else if subColType == "array_composite" {
slog.Info("array composite type detected, fetching element type details",
slog.Any("subfieldTypeOID", subfield.Type.OID),
slog.String("subfieldName", subfield.Name),
)
var elemOID uint32
err := c.conn.QueryRow(ctx,
`select typelem from pg_type where oid = $1`, subfield.Type.OID).Scan(&elemOID)
if err != nil {
return nil, fmt.Errorf("error getting array element type OID for %d: %w", subfield.Type.OID, err)
}
elemTypeDetails, err := c.getCompositeTypeDetails(ctx, system, version, customTypeMapping, elemOID)

subCompositeFields = append(subCompositeFields, &protos.FieldDescription{
Name: subfield.Name,
Type: "composite",
TypeModifier: -1,
Nullable: false,
Composite: elemTypeDetails,
})
if err != nil {
return nil, fmt.Errorf("error getting composite type details for array element %d: %w", elemOID, err)
}
}
result = append(result, &protos.FieldDescription{
Name: subfield.Name,
Type: subColType,
TypeModifier: subfield.TypeModifier,
Nullable: !subfield.NotNull,
Composite: subCompositeFields,
})
}
return result, nil
}

func (c *PostgresConnector) getTableSchemaForTable(
ctx context.Context,
env map[string]string,
Expand All @@ -934,7 +1014,7 @@ func (c *PostgresConnector) getTableSchemaForTable(
if err != nil {
return nil, err
}
customTypeMapping, err := c.fetchCustomTypeMapping(ctx)
customTypeMapping, err := c.fetchCustomTypeMapping(ctx, version)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -989,12 +1069,21 @@ func (c *PostgresConnector) getTableSchemaForTable(
if err != nil {
return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err)
}
defer rows.Close()

fields := rows.FieldDescriptions()
columnNames := make([]string, 0, len(fields))
columns := make([]*protos.FieldDescription, 0, len(fields))
for _, fieldDescription := range fields {
var fieldsCopy []pgconn.FieldDescription
{
fields := rows.FieldDescriptions()
fieldsCopy = make([]pgconn.FieldDescription, len(fields))
copy(fieldsCopy, fields)
}

rows.Close()
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over table schema: %w", err)
}
columnNames := make([]string, 0, len(fieldsCopy))
columns := make([]*protos.FieldDescription, 0, len(fieldsCopy))
for _, fieldDescription := range fieldsCopy {
var colType string
var err error
switch system {
Expand All @@ -1008,19 +1097,50 @@ func (c *PostgresConnector) getTableSchemaForTable(
return nil, fmt.Errorf("error getting type name for %d: %w", fieldDescription.DataTypeOID, err)
}

composite := make([]*protos.FieldDescription, 0)
if colType == "composite" {
subtypes, err := c.getCompositeTypeDetails(ctx, system, version, customTypeMapping, fieldDescription.DataTypeOID)
if err != nil {
return nil, fmt.Errorf("error getting composite type details for %d: %w", fieldDescription.DataTypeOID, err)
}
composite = subtypes
}
if colType == "array_composite" {
slog.Info("array composite type detected, fetching element type details",
slog.Any("subfieldTypeOID", fieldDescription.DataTypeOID),
slog.String("subfieldName", fieldDescription.Name),
)
var elemOID uint32
err := c.conn.QueryRow(ctx,
`select typelem from pg_type where oid = $1`, fieldDescription.DataTypeOID).Scan(&elemOID)
if err != nil {
return nil, fmt.Errorf("error getting array element type OID for %d: %w", fieldDescription.DataTypeOID, err)
}
elemTypeDetails, err := c.getCompositeTypeDetails(ctx, system, version, customTypeMapping, elemOID)

composite = append(composite, &protos.FieldDescription{
Name: fieldDescription.Name,
Type: "composite",
TypeModifier: -1,
Nullable: false,
Composite: elemTypeDetails,
})
if err != nil {
return nil, fmt.Errorf("error getting composite type details for array element %d: %w", elemOID, err)
}
}

columnNames = append(columnNames, fieldDescription.Name)
slog.Info("fieldDescription", slog.String("name", fieldDescription.Name))
_, nullable := nullableCols[fieldDescription.Name]
columns = append(columns, &protos.FieldDescription{
Name: fieldDescription.Name,
Type: colType,
TypeModifier: fieldDescription.TypeModifier,
Nullable: nullable,
Composite: composite,
})
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over table schema: %w", err)
}
// if we have no pkey, we will use all columns as the pkey for the MERGE statement
if replicaIdentityType == ReplicaIdentityFull && len(pKeyCols) == 0 {
pKeyCols = columnNames
Expand Down
90 changes: 71 additions & 19 deletions flow/connectors/postgres/qrep_query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (c *PostgresConnector) NewQRepQueryExecutor(ctx context.Context, version ui
func (c *PostgresConnector) NewQRepQueryExecutorSnapshot(ctx context.Context, version uint32,
snapshot string, flowJobName string, partitionID string,
) (*QRepQueryExecutor, error) {
if _, err := c.fetchCustomTypeMapping(ctx); err != nil {
if _, err := c.fetchCustomTypeMapping(ctx, version); err != nil {
c.logger.Error("[pg_query_executor] failed to fetch custom type mapping", slog.Any("error", err))
return nil, fmt.Errorf("failed to fetch custom type mapping: %w", err)
}
Expand Down Expand Up @@ -70,28 +70,80 @@ func (qe *QRepQueryExecutor) executeQueryInTx(ctx context.Context, tx pgx.Tx, cu
return rows, nil
}

// buildQFieldFromOID recursively builds a QField from a PostgreSQL OID, handling nested composite types
func (qe *QRepQueryExecutor) buildQFieldFromOID(name string, oid uint32, typeModifier int32) types.QField {
ctype := qe.postgresOIDToQValueKind(oid, qe.customTypeMapping, qe.version)

if ctype == types.QValueKindNumeric || ctype == types.QValueKindArrayNumeric {
precision, scale := datatypes.ParseNumericTypmod(typeModifier)
return types.QField{
Name: name,
Type: ctype,
Nullable: true,
Precision: precision,
Scale: scale,
}
} else if ctype == types.QValueKindComposite {
if typ, ok := qe.typeMap.TypeForOID(oid); ok {
if cc, ok := typ.Codec.(*pgtype.CompositeCodec); ok {
subQFields := make([]*types.QField, 0)
for _, f := range cc.Fields {
// Recursively build subfields, handling nested composite types
subField := qe.buildQFieldFromOID(f.Name, f.Type.OID, -1)
subQFields = append(subQFields, &subField)
}
return types.QField{
Name: name,
Type: ctype,
Nullable: true,
SubFields: subQFields,
}
}
}
qe.logger.Error("[pg_query_executor] type not found for oid or not a composite type",
slog.Uint64("type_oid", uint64(oid)),
slog.String("type_name", name))
return types.QField{
Name: name,
Type: ctype,
Nullable: true,
}
} else if ctype == types.QValueKindArrayComposite {
if typ, ok := qe.typeMap.TypeForOID(oid); ok {
if ac, ok := typ.Codec.(*pgtype.ArrayCodec); ok {
subQFields := make([]*types.QField, 0)
subField := qe.buildQFieldFromOID(ac.ElementType.Name, ac.ElementType.OID, -1)
subQFields = append(subQFields, &subField)
return types.QField{
Name: name,
Type: ctype,
Nullable: true,
SubFields: subQFields,
}
}
}
qe.logger.Error("[pg_query_executor] type not found for oid or not an array composite type",
slog.Uint64("type_oid", uint64(oid)),
slog.String("type_name", name))
return types.QField{
Name: name,
Type: ctype,
Nullable: true,
}
} else {
return types.QField{
Name: name,
Type: ctype,
Nullable: true,
}
}
}

// FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema.
func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescription) types.QRecordSchema {
qfields := make([]types.QField, len(fds))
for i, fd := range fds {
ctype := qe.postgresOIDToQValueKind(fd.DataTypeOID, qe.customTypeMapping, qe.version)
// there isn't a way to know if a column is nullable or not
if ctype == types.QValueKindNumeric || ctype == types.QValueKindArrayNumeric {
precision, scale := datatypes.ParseNumericTypmod(fd.TypeModifier)
qfields[i] = types.QField{
Name: fd.Name,
Type: ctype,
Nullable: true,
Precision: precision,
Scale: scale,
}
} else {
qfields[i] = types.QField{
Name: fd.Name,
Type: ctype,
Nullable: true,
}
}
qfields[i] = qe.buildQFieldFromOID(fd.Name, fd.DataTypeOID, fd.TypeModifier)
}
return types.NewQRecordSchema(qfields)
}
Expand Down
Loading
Loading