Skip to content

Commit 68db859

Browse files
authored
refactor(ingest): streamline two-tier db config validation (#5986)
1 parent b638bcf commit 68db859

File tree

4 files changed

+57
-35
lines changed

4 files changed

+57
-35
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import warnings
2+
from typing import Callable, Type, TypeVar
3+
4+
import pydantic
5+
6+
_T = TypeVar("_T")
7+
8+
9+
def _default_rename_transform(value: _T) -> _T:
10+
return value
11+
12+
13+
def pydantic_renamed_field(
    old_name: str,
    new_name: str,
    transform: Callable[[_T], _T] = _default_rename_transform,
) -> classmethod:
    """Build a pydantic pre=True root validator that migrates a deprecated
    config field to its replacement.

    If only ``old_name`` is present, its value is moved (via ``transform``)
    to ``new_name`` and a deprecation warning is emitted. Supplying both
    names at once raises a ValueError.
    """

    def _validate_field_rename(cls: Type, values: dict) -> dict:
        # Nothing to migrate — pass the raw values straight through.
        if old_name not in values:
            return values
        if new_name in values:
            raise ValueError(
                f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}."
            )
        warnings.warn(
            f"The {old_name} is deprecated, please use {new_name} instead.",
            UserWarning,
        )
        values[new_name] = transform(values.pop(old_name))
        return values

    # Why aren't we using pydantic.validator here?
    # The `values` argument that is passed to field validators only contains items
    # that have already been validated in the pre-process phase, which happens if
    # they have an associated field and a pre=True validator. However, the root
    # validator with pre=True gets all the values that were passed in.
    # Given that a renamed field doesn't show up in the fields list, we can't use
    # the field-level validator, even with a different field name.
    return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)

metadata-ingestion/src/datahub/ingestion/source/sql/two_tier_sql_source.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import typing
2-
from typing import Any, Dict
32

4-
import pydantic
53
from pydantic.fields import Field
64
from sqlalchemy import create_engine, inspect
75
from sqlalchemy.engine.reflection import Inspector
86

97
from datahub.configuration.common import AllowDenyPattern
8+
from datahub.configuration.validate_field_rename import pydantic_renamed_field
109
from datahub.emitter.mcp_builder import PlatformKey
1110
from datahub.ingestion.api.workunit import MetadataWorkUnit
1211
from datahub.ingestion.source.sql.sql_common import (
@@ -24,40 +23,26 @@ class TwoTierSQLAlchemyConfig(BasicSQLAlchemyConfig):
2423
description="Regex patterns for databases to filter in ingestion.",
2524
)
2625
schema_pattern: AllowDenyPattern = Field(
26+
# The superclass contains a `schema_pattern` field, so we need this here
27+
# to override the documentation.
2728
default=AllowDenyPattern.allow_all(),
28-
description="Deprecated in favour of database_pattern. Regex patterns for schemas to filter in ingestion. "
29-
"Specify regex to only match the schema name. e.g. to match all tables in schema analytics, "
30-
"use the regex 'analytics'",
29+
description="Deprecated in favour of database_pattern.",
3130
)
3231

33-
@pydantic.root_validator()
34-
def ensure_profiling_pattern_is_passed_to_profiling(
35-
cls, values: Dict[str, Any]
36-
) -> Dict[str, Any]:
37-
allow_all_pattern = AllowDenyPattern.allow_all()
38-
schema_pattern = values.get("schema_pattern")
39-
database_pattern = values.get("database_pattern")
40-
if (
41-
database_pattern == allow_all_pattern
42-
and schema_pattern != allow_all_pattern
43-
):
44-
logger.warning(
45-
"Updating 'database_pattern' to 'schema_pattern'. Please stop using deprecated "
46-
"'schema_pattern'. Use 'database_pattern' instead. "
47-
)
48-
values["database_pattern"] = schema_pattern
49-
return values
32+
_schema_pattern_deprecated = pydantic_renamed_field(
33+
"schema_pattern", "database_pattern"
34+
)
5035

5136
def get_sql_alchemy_url(
    self,
    uri_opts: typing.Optional[typing.Dict[str, typing.Any]] = None,
    current_db: typing.Optional[str] = None,
) -> str:
    """Return the explicitly configured SQLAlchemy URI if one was given,
    otherwise assemble one from the individual connection settings.

    ``current_db`` overrides the configured database name when provided.
    """
    if self.sqlalchemy_uri:
        return self.sqlalchemy_uri

    database = current_db if current_db else self.database
    secret = self.password.get_secret_value() if self.password else None
    return make_sqlalchemy_uri(
        self.scheme,
        self.username,
        secret,
        self.host_port,
        database,
        uri_opts=uri_opts,
    )
@@ -70,6 +55,8 @@ def __init__(self, config, ctx, platform):
7055
self.config: TwoTierSQLAlchemyConfig = config
7156

7257
def get_parent_container_key(self, db_name: str, schema: str) -> PlatformKey:
    """Resolve the parent container for a schema to its database key.

    The ``schema`` argument is deliberately ignored: our overridden
    get_allowed_schemas method returns db_name as the schema name, so
    ``db_name`` and ``schema`` are always the same value here.
    """
    return self.gen_database_key(db_name)
7461

7562
def get_allowed_schemas(

metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
JobId,
1515
JobStateKey,
1616
)
17-
from datahub.ingestion.source.sql.mysql import MySQLConfig
17+
from datahub.ingestion.source.sql.postgres import PostgresConfig
1818
from datahub.ingestion.source.state.checkpoint import Checkpoint
1919
from datahub.ingestion.source.state.sql_common_state import (
2020
BaseSQLAlchemyCheckpointState,
@@ -124,7 +124,7 @@ def test_provider(self):
124124
pipeline_name=self.pipeline_name,
125125
platform_instance_id=self.platform_instance_id,
126126
run_id=self.run_id,
127-
config=MySQLConfig(),
127+
config=PostgresConfig(host_port="localhost:5432"),
128128
state=job1_state_obj,
129129
)
130130
# Job2 - Checkpoint with a BaseUsageCheckpointState state
@@ -136,22 +136,18 @@ def test_provider(self):
136136
pipeline_name=self.pipeline_name,
137137
platform_instance_id=self.platform_instance_id,
138138
run_id=self.run_id,
139-
config=MySQLConfig(),
139+
config=PostgresConfig(host_port="localhost:5432"),
140140
state=job2_state_obj,
141141
)
142142

143143
# 2. Set the provider's state_to_commit.
144144
self.provider.state_to_commit = {
145145
# NOTE: state_to_commit accepts only the aspect version of the checkpoint.
146146
self.job_names[0]: job1_checkpoint.to_checkpoint_aspect(
147-
# fmt: off
148147
max_allowed_state_size=2**20
149-
# fmt: on
150148
),
151149
self.job_names[1]: job2_checkpoint.to_checkpoint_aspect(
152-
# fmt: off
153150
max_allowed_state_size=2**20
154-
# fmt: on
155151
),
156152
}
157153

metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from datahub.emitter.mce_builder import make_dataset_urn
7-
from datahub.ingestion.source.sql.mysql import MySQLConfig
7+
from datahub.ingestion.source.sql.postgres import PostgresConfig
88
from datahub.ingestion.source.sql.sql_common import BasicSQLAlchemyConfig
99
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
1010
from datahub.ingestion.source.state.sql_common_state import (
@@ -21,7 +21,7 @@
2121
test_platform_instance_id: str = "test_platform_instance_1"
2222
test_job_name: str = "test_job_1"
2323
test_run_id: str = "test_run_1"
24-
test_source_config: BasicSQLAlchemyConfig = MySQLConfig()
24+
test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")
2525

2626
# 2. Create the params for parametrized tests.
2727

@@ -79,7 +79,7 @@ def test_create_from_checkpoint_aspect(state_obj):
7979
job_name=test_job_name,
8080
checkpoint_aspect=checkpoint_aspect,
8181
state_class=type(state_obj),
82-
config_class=MySQLConfig,
82+
config_class=PostgresConfig,
8383
)
8484

8585
expected_checkpoint_obj = Checkpoint(
@@ -125,6 +125,6 @@ def test_serde_idempotence(state_obj):
125125
job_name=test_job_name,
126126
checkpoint_aspect=checkpoint_aspect,
127127
state_class=type(state_obj),
128-
config_class=MySQLConfig,
128+
config_class=PostgresConfig,
129129
)
130130
assert orig_checkpoint_obj == serde_checkpoint_obj

0 commit comments

Comments
 (0)