datahub/configuration/validate_field_rename.py (new file)
@@ -0,0 +1,39 @@
import warnings
from typing import Callable, Type, TypeVar

import pydantic

_T = TypeVar("_T")


def _default_rename_transform(value: _T) -> _T:
return value


def pydantic_renamed_field(
old_name: str,
new_name: str,
transform: Callable[[_T], _T] = _default_rename_transform,
) -> classmethod:
def _validate_field_rename(cls: Type, values: dict) -> dict:
if old_name in values:
if new_name in values:
raise ValueError(
f"Cannot specify both {old_name} and {new_name} in the same config. Note that {old_name} has been deprecated in favor of {new_name}."
)
else:
warnings.warn(
f"The {old_name} is deprecated, please use {new_name} instead.",
UserWarning,
)
values[new_name] = transform(values.pop(old_name))
return values

# Why aren't we using pydantic.validator here?
# The `values` argument that is passed to field validators only contains items
# that have already been validated in the pre-process phase, which happens if
# they have an associated field and a pre=True validator. However, the root
# validator with pre=True gets all the values that were passed in.
# Given that a renamed field doesn't show up in the fields list, we can't use
# the field-level validator, even with a different field name.
return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename)
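A minimal usage sketch, not part of the diff: `MyConfig` and its field names are hypothetical. Assigning the returned root validator to any class attribute is enough for pydantic's (v1) model metaclass to pick it up.

import pydantic

class MyConfig(pydantic.BaseModel):
    new_field: str = ""

    _rename_old_field = pydantic_renamed_field("old_field", "new_field")

MyConfig.parse_obj({"old_field": "x"})  # emits a UserWarning, sets new_field="x"
# Passing both "old_field" and "new_field" fails validation: the ValueError
# raised above surfaces as a pydantic.ValidationError.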
datahub/ingestion/source/sql/two_tier_sql_source.py
@@ -1,12 +1,11 @@
import typing
from typing import Any, Dict

import pydantic
from pydantic.fields import Field
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.validate_field_rename import pydantic_renamed_field
from datahub.emitter.mcp_builder import PlatformKey
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.sql.sql_common import (
@@ -24,40 +23,26 @@ class TwoTierSQLAlchemyConfig(BasicSQLAlchemyConfig):
description="Regex patterns for databases to filter in ingestion.",
)
schema_pattern: AllowDenyPattern = Field(
# The superclass contains a `schema_pattern` field, so we need this here
# to override the documentation.
default=AllowDenyPattern.allow_all(),
description="Deprecated in favour of database_pattern. Regex patterns for schemas to filter in ingestion. "
"Specify regex to only match the schema name. e.g. to match all tables in schema analytics, "
"use the regex 'analytics'",
description="Deprecated in favour of database_pattern.",
)

@pydantic.root_validator()
def ensure_profiling_pattern_is_passed_to_profiling(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
allow_all_pattern = AllowDenyPattern.allow_all()
schema_pattern = values.get("schema_pattern")
database_pattern = values.get("database_pattern")
if (
database_pattern == allow_all_pattern
and schema_pattern != allow_all_pattern
):
logger.warning(
"Updating 'database_pattern' to 'schema_pattern'. Please stop using deprecated "
"'schema_pattern'. Use 'database_pattern' instead. "
)
values["database_pattern"] = schema_pattern
return values
_schema_pattern_deprecated = pydantic_renamed_field(
"schema_pattern", "database_pattern"
)

def get_sql_alchemy_url(
self,
uri_opts: typing.Optional[typing.Dict[str, typing.Any]] = None,
current_db: typing.Optional[str] = None,
) -> str:
return self.sqlalchemy_uri or make_sqlalchemy_uri(
self.scheme, # type: ignore
self.scheme,
self.username,
self.password.get_secret_value() if self.password else None,
self.host_port, # type: ignore
self.host_port,
current_db if current_db else self.database,
uri_opts=uri_opts,
)
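For illustration, a hypothetical two-tier config (all values made up) resolves to a plain scheme://user@host:port/database URL:

config = TwoTierSQLAlchemyConfig(
    scheme="mysql+pymysql",
    username="user",
    host_port="localhost:3306",
    database="analytics",
)
# Roughly "mysql+pymysql://user@localhost:3306/analytics"
print(config.get_sql_alchemy_url())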
@@ -70,6 +55,8 @@ def __init__(self, config, ctx, platform):
self.config: TwoTierSQLAlchemyConfig = config

def get_parent_container_key(self, db_name: str, schema: str) -> PlatformKey:
# Because our overridden get_allowed_schemas method returns db_name as the schema name,
# the db_name and schema here will be the same. Hence, we just ignore the schema parameter.
return self.gen_database_key(db_name)

def get_allowed_schemas(
@@ -14,7 +14,7 @@
JobId,
JobStateKey,
)
from datahub.ingestion.source.sql.mysql import MySQLConfig
from datahub.ingestion.source.sql.postgres import PostgresConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
@@ -124,7 +124,7 @@ def test_provider(self):
pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id,
run_id=self.run_id,
config=MySQLConfig(),
config=PostgresConfig(host_port="localhost:5432"),
state=job1_state_obj,
)
# Job2 - Checkpoint with a BaseUsageCheckpointState state
@@ -136,22 +136,18 @@
pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id,
run_id=self.run_id,
config=MySQLConfig(),
config=PostgresConfig(host_port="localhost:5432"),
state=job2_state_obj,
)

# 2. Set the provider's state_to_commit.
self.provider.state_to_commit = {
# NOTE: state_to_commit accepts only the aspect version of the checkpoint.
self.job_names[0]: job1_checkpoint.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
),
self.job_names[1]: job2_checkpoint.to_checkpoint_aspect(
# fmt: off
max_allowed_state_size=2**20
# fmt: on
),
}
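As context for the aspect-only note above, the aspect form round-trips back into a Checkpoint; this sketch reuses the names from the surrounding test:

aspect = job1_checkpoint.to_checkpoint_aspect(max_allowed_state_size=2**20)
restored = Checkpoint.create_from_checkpoint_aspect(
    job_name=self.job_names[0],
    checkpoint_aspect=aspect,
    state_class=type(job1_state_obj),
    config_class=PostgresConfig,
)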

@@ -4,7 +4,7 @@
import pytest

from datahub.emitter.mce_builder import make_dataset_urn
from datahub.ingestion.source.sql.mysql import MySQLConfig
from datahub.ingestion.source.sql.postgres import PostgresConfig
from datahub.ingestion.source.sql.sql_common import BasicSQLAlchemyConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase
from datahub.ingestion.source.state.sql_common_state import (
@@ -21,7 +21,7 @@
test_platform_instance_id: str = "test_platform_instance_1"
test_job_name: str = "test_job_1"
test_run_id: str = "test_run_1"
test_source_config: BasicSQLAlchemyConfig = MySQLConfig()
test_source_config: BasicSQLAlchemyConfig = PostgresConfig(host_port="test_host:1234")

# 2. Create the params for parametrized tests.

@@ -79,7 +79,7 @@ def test_create_from_checkpoint_aspect(state_obj):
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
config_class=MySQLConfig,
config_class=PostgresConfig,
)

expected_checkpoint_obj = Checkpoint(
@@ -125,6 +125,6 @@ def test_serde_idempotence(state_obj):
job_name=test_job_name,
checkpoint_aspect=checkpoint_aspect,
state_class=type(state_obj),
config_class=MySQLConfig,
config_class=PostgresConfig,
)
assert orig_checkpoint_obj == serde_checkpoint_obj