Skip to content
This repository was archived by the owner on Jul 3, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions hamilton/data_quality/default_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd

from hamilton.data_quality import base
from hamilton.data_quality.base import BaseDefaultValidator
from hamilton.data_quality.base import BaseDefaultValidator, ValidationResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -206,9 +206,9 @@ def validate(self, data: pd.Series) -> base.ValidationResult:
return base.ValidationResult(
passes=passes,
message=f'Out of {total_length} items in the series, {total_na} of them are Nan, '
f'representing: {MaxFractionNansValidatorPandasSeries._to_percent(fraction_na)}. '
f'Max allowable Nans is: {MaxFractionNansValidatorPandasSeries._to_percent(self.max_fraction_nans)},'
f' so this {"passes" if passes else "does not pass"}.',
f'representing: {MaxFractionNansValidatorPandasSeries._to_percent(fraction_na)}. '
f'Max allowable Nans is: {MaxFractionNansValidatorPandasSeries._to_percent(self.max_fraction_nans)},'
f' so this {"passes" if passes else "does not pass"}.',
diagnostics={
'total_nan': total_na,
'total_length': total_length,
Expand Down Expand Up @@ -299,7 +299,7 @@ def validate(self, data: Union[numbers.Real, str, bool, int, float, list, dict])
return base.ValidationResult(
passes=passes,
message=f'Requires data type: {self.datatype}. '
f"Got data type: {type(data)}. This {'is' if passes else 'is not'} a match.",
f"Got data type: {type(data)}. This {'is' if passes else 'is not'} a match.",
diagnostics={
'required_data_type': self.datatype,
'actual_data_type': type(data)
Expand Down Expand Up @@ -329,8 +329,8 @@ def validate(self, data: pd.Series) -> base.ValidationResult:
return base.ValidationResult(
passes=passes,
message=f'Max allowable standard dev is: {self.max_standard_dev}. '
f'Dataset stddev is : {standard_dev}. '
f"This {'passes' if passes else 'does not pass'}.",
f'Dataset stddev is : {standard_dev}. '
f"This {'passes' if passes else 'does not pass'}.",
diagnostics={
'standard_dev': standard_dev,
'max_standard_dev': self.max_standard_dev
Expand Down Expand Up @@ -362,7 +362,7 @@ def validate(self, data: pd.Series) -> base.ValidationResult:
return base.ValidationResult(
passes=passes,
message=f"Dataset has mean: {dataset_mean}. This {'is ' if passes else 'is not '} "
f'in the required range: [{self.mean_in_range[0]}, {self.mean_in_range[1]}].',
f'in the required range: [{self.mean_in_range[0]}, {self.mean_in_range[1]}].',
diagnostics={
'dataset_mean': dataset_mean,
'mean_in_range': self.mean_in_range
Expand All @@ -374,6 +374,37 @@ def arg(cls) -> str:
return 'mean_in_range'


class AllowNoneValidator(BaseDefaultValidator):
    """Validator that (optionally) checks an output is not None.

    When ``allow_none`` is True this is a no-op that always passes; when it is
    False, validation fails iff the output is None.
    """

    def __init__(self, allow_none: bool, importance: str):
        """Constructor.

        :param allow_none: Whether None is an acceptable output value.
        :param importance: Importance level -- forwarded to the base validator.
        """
        super(AllowNoneValidator, self).__init__(importance)
        self.allow_none = allow_none

    @classmethod
    def applies_to(cls, datatype: Type[Type]) -> bool:
        # Any output can be None at runtime, so this validator applies to every type.
        return True

    def description(self) -> str:
        if self.allow_none:
            return 'No-op validator.'
        # Fixed typo: stray ';' removed from the message.
        return 'Validates that an output is not None'

    def validate(self, data: Any) -> ValidationResult:
        """Validates the data, failing only when None is disallowed and data is None.

        :param data: Output to validate -- may be of any type.
        :return: ValidationResult with a human-readable message and empty diagnostics.
        """
        # Passes unless None is disallowed and the data actually is None.
        passes = self.allow_none or data is not None
        return ValidationResult(
            passes=passes,
            message=f'Data is not allowed to be None, got {data}' if not passes else 'Data is not None',
            diagnostics={}  # Nothing necessary here...
        )

    @classmethod
    def arg(cls) -> str:
        return 'allow_none'


AVAILABLE_DEFAULT_VALIDATORS = [
AllowNaNsValidatorPandasSeries,
DataInRangeValidatorPandasSeries,
Expand All @@ -385,6 +416,7 @@ def arg(cls) -> str:
MaxFractionNansValidatorPandasSeries,
MaxStandardDevValidatorPandasSeries,
MeanInRangeValidatorPandasSeries,
AllowNoneValidator,
]


Expand Down
4 changes: 2 additions & 2 deletions hamilton/function_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
import typing_inspect

from hamilton import node
from hamilton import node, type_utils
from hamilton.data_quality import base as dq_base
from hamilton.data_quality import default_validators
from hamilton import function_modifiers_base
Expand Down Expand Up @@ -441,7 +441,7 @@ def ensure_output_types_match(fn: Callable, todo: Callable):
"""
annotation_fn = inspect.signature(fn).return_annotation
annotation_todo = inspect.signature(todo).return_annotation
if not issubclass(annotation_todo, annotation_fn):
if not type_utils.custom_subclass_check(annotation_fn, annotation_todo):
raise InvalidDecoratorException(f'Output types: {annotation_fn} and {annotation_todo} are not compatible')

@staticmethod
Expand Down
91 changes: 25 additions & 66 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,60 +13,19 @@

import typing_inspect

import hamilton.function_modifiers_base
from hamilton import function_modifiers_base
from hamilton import node
from hamilton.node import NodeSource, DependencyType
from hamilton import base
from hamilton import type_utils

logger = logging.getLogger(__name__)
BASE_ARGS_FOR_GENERICS = (typing.T,)


def is_submodule(child: ModuleType, parent: ModuleType) -> bool:
    """Whether ``child`` is ``parent`` itself or lives inside ``parent``'s package.

    The previous substring check (``parent.__name__ in child.__name__``) wrongly
    matched unrelated modules, e.g. parent ``pkg`` "contained" in sibling
    ``pkgother``. This compares on dotted-path boundaries instead.

    :param child: Module that may be nested inside ``parent``.
    :param parent: Candidate enclosing package/module.
    :return: True if child is parent or a (transitive) submodule of it.
    """
    child_name, parent_name = child.__name__, parent.__name__
    return child_name == parent_name or child_name.startswith(parent_name + '.')


def custom_subclass_check(requested_type: Type[Type], param_type: Type[Type]) -> bool:
    """This is a custom check around generics & classes. It probably misses a few edge cases.

    Treats two types as compatible when their generic origins match (e.g.
    ``List[int]`` vs ``list``) and their type arguments agree or are left
    unspecified; otherwise falls back to ``issubclass`` checks.

    We will likely need to revisit this in the future (perhaps integrate with graphadapter?)

    :param requested_type: Candidate subclass
    :param param_type: Type of parameter to check
    :return: Whether or not this is a valid subclass.
    """
    # handles case when someone is using primitives and generics
    requested_origin_type = requested_type
    param_origin_type = param_type
    has_generic = False
    # Unwrap generics/tuples to their origin (List[int] -> list) so origins compare directly.
    if typing_inspect.is_generic_type(requested_type) or typing_inspect.is_tuple_type(requested_type):
        requested_origin_type = typing_inspect.get_origin(requested_type)
        has_generic = True
    if typing_inspect.is_generic_type(param_type) or typing_inspect.is_tuple_type(param_type):
        param_origin_type = typing_inspect.get_origin(param_type)
        has_generic = True
    if requested_origin_type == param_origin_type:
        if has_generic:  # check the args match or they do not have them defined.
            requested_args = typing_inspect.get_args(requested_type)
            param_args = typing_inspect.get_args(param_type)
            # Only require equal args when both sides declare concrete args; a bare
            # generic (args == BASE_ARGS_FOR_GENERICS, i.e. (typing.T,)) matches anything.
            if (requested_args and param_args
                    and requested_args != BASE_ARGS_FOR_GENERICS and param_args != BASE_ARGS_FOR_GENERICS):
                return requested_args == param_args
        return True

    if ((typing_inspect.is_generic_type(requested_type) and typing_inspect.is_generic_type(param_type)) or
            (inspect.isclass(requested_type) and typing_inspect.is_generic_type(param_type))):
        # we're comparing two generics that aren't equal -- check if Mapping vs Dict
        # or we're comparing a class to a generic -- check if Mapping vs dict
        # the precedence is that requested will go into the param_type, so the param_type should be more permissive.
        return issubclass(requested_type, param_type)
    # classes - precedence is that requested will go into the param_type, so the param_type should be more permissive.
    if inspect.isclass(requested_type) and inspect.isclass(param_type) and issubclass(requested_type, param_type):
        return True
    return False


def types_match(adapter: base.HamiltonGraphAdapter,
param_type: Type[Type],
required_node_type: Any) -> bool:
Expand All @@ -87,7 +46,7 @@ def types_match(adapter: base.HamiltonGraphAdapter,
return required_node_type == param_type
elif required_node_type == param_type:
return True
elif custom_subclass_check(required_node_type, param_type):
elif type_utils.custom_subclass_check(required_node_type, param_type):
return True
elif adapter.check_node_type_equivalence(required_node_type, param_type):
return True
Expand Down Expand Up @@ -131,7 +90,7 @@ def add_dependency(
f'{param_name}:{required_node.type}. All names & types must match.')
else:
# this is a user defined var
required_node = node.Node(param_name, param_type, node_source=NodeSource.EXTERNAL)
required_node = node.Node(param_name, param_type, node_source=node.NodeSource.EXTERNAL)
nodes[param_name] = required_node
# add edges
func_node.dependencies.append(required_node)
Expand All @@ -151,7 +110,7 @@ def create_function_graph(*modules: ModuleType, config: Dict[str, Any], adapter:

# create nodes -- easier to just create this in one loop
for func_name, f in functions:
for n in hamilton.function_modifiers_base.resolve_nodes(f, config):
for n in function_modifiers_base.resolve_nodes(f, config):
if n.name in config:
continue # This makes sure we overwrite things if they're in the config...
if n.name in nodes:
Expand All @@ -164,7 +123,7 @@ def create_function_graph(*modules: ModuleType, config: Dict[str, Any], adapter:
add_dependency(n, node_name, nodes, param_name, param_type, adapter)
for key in config.keys():
if key not in nodes:
nodes[key] = node.Node(key, Any, node_source=NodeSource.EXTERNAL)
nodes[key] = node.Node(key, Any, node_source=node.NodeSource.EXTERNAL)
return nodes


Expand Down Expand Up @@ -358,7 +317,7 @@ def next_nodes_function(n: node.Node) -> List[node.Node]:
# If inputs is None, we want to assume its required, as it is a compile-time dependency
if dep.user_defined and dep.name not in runtime_inputs and dep.name not in self.config:
_, dependency_type = n.input_types[dep.name]
if dependency_type == DependencyType.OPTIONAL:
if dependency_type == node.DependencyType.OPTIONAL:
continue
deps.append(dep)
return deps
Expand Down Expand Up @@ -423,41 +382,41 @@ def execute_static(nodes: Collection[node.Node],
if computed is None:
computed = {}

def dfs_traverse(node: node.Node, dependency_type: DependencyType = DependencyType.REQUIRED):
if node.name in computed:
def dfs_traverse(node_: node.Node, dependency_type: node.DependencyType = node.DependencyType.REQUIRED):
if node_.name in computed:
return
if node.name in overrides:
computed[node.name] = overrides[node.name]
if node_.name in overrides:
computed[node_.name] = overrides[node_.name]
return
for n in node.dependencies:
for n in node_.dependencies:
if n.name not in computed:
_, node_dependency_type = node.input_types[n.name]
_, node_dependency_type = node_.input_types[n.name]
dfs_traverse(n, node_dependency_type)

logger.debug(f'Computing {node.name}.')
if node.user_defined:
if node.name not in inputs:
if dependency_type != DependencyType.OPTIONAL:
raise NotImplementedError(f'{node.name} was expected to be passed in but was not.')
logger.debug(f'Computing {node_.name}.')
if node_.user_defined:
if node_.name not in inputs:
if dependency_type != node.DependencyType.OPTIONAL:
raise NotImplementedError(f'{node_.name} was expected to be passed in but was not.')
return
value = inputs[node.name]
value = inputs[node_.name]
else:
kwargs = {} # construct signature
for dependency in node.dependencies:
for dependency in node_.dependencies:
if dependency.name in computed:
kwargs[dependency.name] = computed[dependency.name]
try:
value = adapter.execute_node(node, kwargs)
value = adapter.execute_node(node_, kwargs)
except Exception as e:
logger.exception(f'Node {node.name} encountered an error')
logger.exception(f'Node {node_.name} encountered an error')
raise
computed[node.name] = value
computed[node_.name] = value

for final_var_node in nodes:
dep_type = DependencyType.REQUIRED
dep_type = node.DependencyType.REQUIRED
if final_var_node.user_defined:
# from the top level, we don't know if this UserInput is required. So mark as optional.
dep_type = DependencyType.OPTIONAL
dep_type = node.DependencyType.OPTIONAL
dfs_traverse(final_var_node, dep_type)
return computed

Expand Down
47 changes: 47 additions & 0 deletions hamilton/type_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import inspect
from typing import Type

import typing
import typing_inspect

BASE_ARGS_FOR_GENERICS = (typing.T,)


def custom_subclass_check(requested_type: Type[Type], param_type: Type[Type]) -> bool:
    """This is a custom check around generics & classes. It probably misses a few edge cases.

    Treats two types as compatible when their generic origins match (e.g.
    ``List[int]`` vs ``list``) and their type arguments agree or are left
    unspecified; otherwise falls back to ``issubclass`` checks.

    We will likely need to revisit this in the future (perhaps integrate with graphadapter?)

    :param requested_type: Candidate subclass
    :param param_type: Type of parameter to check
    :return: Whether or not this is a valid subclass.
    """
    # handles case when someone is using primitives and generics
    requested_origin_type = requested_type
    param_origin_type = param_type
    has_generic = False
    # Unwrap generics/tuples to their origin (List[int] -> list) so origins compare directly.
    if typing_inspect.is_generic_type(requested_type) or typing_inspect.is_tuple_type(requested_type):
        requested_origin_type = typing_inspect.get_origin(requested_type)
        has_generic = True
    if typing_inspect.is_generic_type(param_type) or typing_inspect.is_tuple_type(param_type):
        param_origin_type = typing_inspect.get_origin(param_type)
        has_generic = True
    if requested_origin_type == param_origin_type:
        if has_generic:  # check the args match or they do not have them defined.
            requested_args = typing_inspect.get_args(requested_type)
            param_args = typing_inspect.get_args(param_type)
            # Only require equal args when both sides declare concrete args; a bare
            # generic (args == BASE_ARGS_FOR_GENERICS, i.e. (typing.T,)) matches anything.
            if (requested_args and param_args
                    and requested_args != BASE_ARGS_FOR_GENERICS and param_args != BASE_ARGS_FOR_GENERICS):
                return requested_args == param_args
        return True

    if ((typing_inspect.is_generic_type(requested_type) and typing_inspect.is_generic_type(param_type)) or
            (inspect.isclass(requested_type) and typing_inspect.is_generic_type(param_type))):
        # we're comparing two generics that aren't equal -- check if Mapping vs Dict
        # or we're comparing a class to a generic -- check if Mapping vs dict
        # the precedence is that requested will go into the param_type, so the param_type should be more permissive.
        return issubclass(requested_type, param_type)
    # classes - precedence is that requested will go into the param_type, so the param_type should be more permissive.
    if inspect.isclass(requested_type) and inspect.isclass(param_type) and issubclass(requested_type, param_type):
        return True
    return False
5 changes: 5 additions & 0 deletions tests/test_default_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def test_resolve_default_validators_error(output_type, kwargs, importance):

(default_validators.AllowNaNsValidatorPandasSeries, False, pd.Series([.1, None]), False),
(default_validators.AllowNaNsValidatorPandasSeries, False, pd.Series([.1, .2]), True),

(default_validators.AllowNoneValidator, False, None, False),
(default_validators.AllowNoneValidator, False, 1, True),
(default_validators.AllowNoneValidator, True, None, True),
(default_validators.AllowNoneValidator, True, 1, True),
]
)
def test_default_data_validators(cls: Type[hamilton.data_quality.base.BaseDefaultValidator], param: Any, data: Any, should_pass: bool):
Expand Down
19 changes: 17 additions & 2 deletions tests/test_function_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Dict
from typing import Any, List, Dict, Set

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -338,6 +338,21 @@ def to_modify(param1: int, param2: int) -> int:
assert node.documentation == to_modify.__doc__


def test_does_function_modifier_complex_types():
    """Verifies @does accepts replacement functions whose annotations use complex/generic types."""
    def setify(**kwargs: List[int]) -> Set[int]:
        # Flatten every keyword-arg list into one set.
        return {element for value_list in kwargs.values() for element in value_list}

    def to_modify(param1: List[int], param2: List[int]) -> int:
        """This sums the inputs it gets..."""
        pass

    decorator = does(setify)
    generated = decorator.generate_node(to_modify, {})
    assert generated.name == 'to_modify'
    assert generated.callable(param1=[1, 2, 3], param2=[4, 5, 6]) == {1, 2, 3, 4, 5, 6}
    assert generated.documentation == to_modify.__doc__


def test_model_modifier():
config = {
'my_column_model_params': {
Expand Down Expand Up @@ -672,7 +687,7 @@ def fn(input: pd.Series) -> pd.Series:
data_validators = [value for value in subdag_as_dict.values() if value.tags.get('hamilton.data_quality.contains_dq_results', False)]
assert len(data_validators) == 2 # One for each validator
first_validator, _ = data_validators
assert IS_DATA_VALIDATOR_TAG in first_validator.tags and first_validator.tags[IS_DATA_VALIDATOR_TAG] is True # Validates that all the required tags are included
assert IS_DATA_VALIDATOR_TAG in first_validator.tags and first_validator.tags[IS_DATA_VALIDATOR_TAG] is True # Validates that all the required tags are included
assert DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG in first_validator.tags and first_validator.tags[DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG] == 'fn'

# The final function should take in everything but only use the raw results
Expand Down
3 changes: 2 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import pytest

import hamilton.type_utils
import tests.resources.bad_functions
import tests.resources.config_modifier
import tests.resources.cyclic_functions
Expand Down Expand Up @@ -500,7 +501,7 @@ class Y(X):
])
def test_custom_subclass_check(param_type, required_type, expected):
"""Tests the custom_subclass_check"""
actual = graph.custom_subclass_check(required_type, param_type)
actual = hamilton.type_utils.custom_subclass_check(required_type, param_type)
assert actual == expected


Expand Down