Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass through attributes when creating metrics, dimensions, identifiers #18

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions mensor/backends/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class SQLDialect(object):
'sum': lambda x: "SUM({})".format(x),
'mean': lambda x: "AVG({})".format(x),
'sos': lambda x: "SUM(POW({}, 2))".format(x),
'count': lambda x: "COUNT({})".format(x)
'count': lambda x: "COUNT({})".format(x),
'1': lambda x: "1",
}

TEMPLATE_BASE = textwrap.dedent("""
Expand Down Expand Up @@ -187,7 +188,7 @@ def dialect(self):

def query(self, sql):
print(sql)
raise NotImplementedError("This SQLExecutor goes no further.")
raise NotImplementedError("DebugSQLExecutor prints SQL but cannot execute.")


class SQLMeasureProvider(MeasureProvider):
Expand All @@ -197,15 +198,14 @@ class SQLMeasureProvider(MeasureProvider):

@classmethod
def _on_registered(cls, key):
for agg in ['sum', 'mean', 'sos', 'count']:
for agg in ['sum', 'mean', 'sos', 'count', '1']:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps rename to "any"?

global_stats_registry.aggregations.register(
name=agg,
backend=key,
agg=eval("lambda field, dialect: dialect.AGG_METHODS['{}'](field)".format(agg), {}, {})
)

def __init__(self, *args, sql=None, executor=None, **kwargs):

if not executor:
executor = DebugSQLExecutor()
elif isinstance(executor, str):
Expand Down Expand Up @@ -256,14 +256,13 @@ def get_sql(self, *args, **kwargs):

def _get_ir(self, unit_type, measures, segment_by, where, joins, stats_registry, stats, covariates, **opts):
field_map = self._field_map(unit_type, measures, segment_by, joins)
rebase_agg = not unit_type.is_unique
sql = self._template_environment.get_template(self.dialect.TEMPLATE_BASE).render(
_sql=self._sql(unit_type=unit_type, measures=measures, segment_by=segment_by, where=where, joins=joins, stats=stats, covariates=covariates, **opts),
field_map=field_map,
provider=self,
table_name=self._table_name(unit_type),
dimensions=self._get_dimensions_sql(field_map, segment_by),
measures=self._get_measures_sql(field_map, unit_type, measures, rebase_agg, stats_registry, stats, covariates),
measures=self._get_measures_sql(field_map, unit_type, measures, stats_registry, stats, covariates),
groupby=self._get_groupby_sql(field_map, segment_by),
joins=joins,
constraints=self._get_where_sql(field_map, where),
Expand Down Expand Up @@ -325,29 +324,29 @@ def _get_dimensions_sql(self, field_map, dimensions):
)
return dims

def _get_measures_sql(self, field_map, unit_type, measures, rebase_agg, stats_registry, stats, covariates):
def _get_measures_sql(self, field_map, unit_type, measures, stats_registry, stats, covariates):
aggs = []

rebase_agg = not unit_type.is_unique
if rebase_agg and stats:
raise NotImplementedError("Computing stats and rebasing units simultaneously has not been implemented for the SQL backend.")
else:
for measure in measures:
if not measure.private:
for fieldname, transforms in measure.get_fields(unit_type=unit_type, stats=stats, stats_registry=stats_registry, rebase_agg=rebase_agg).items():

field = '1' if measure == 'count' else field_map['measures'][measure.via_name]
if transforms.get('pre_agg'):
field = transforms['pre_agg'](field, self.dialect)
field = transforms['agg'](field, self.dialect)
if transforms.get('post_agg'):
field = transforms['post_agg'](field, self.dialect)

aggs.append(
'{col_op} AS {f}'.format(
col_op=field,
f=self._col(fieldname),
)

for measure in measures:
if not measure.private:
for fieldname, transforms in measure.get_fields(unit_type=unit_type, stats=stats, stats_registry=stats_registry, rebase_agg=rebase_agg).items():

field = '1' if measure == 'count' else field_map['measures'][measure.via_name]
if transforms.get('pre_agg'):
field = transforms['pre_agg'](field, self.dialect)
field = transforms['agg'](field, self.dialect)
if transforms.get('post_agg'):
field = transforms['post_agg'](field, self.dialect)

aggs.append(
'{col_op} AS {f}'.format(
col_op=field,
f=self._col(fieldname),
)
)

return aggs

Expand Down
8 changes: 7 additions & 1 deletion mensor/measures/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,14 @@ def _features_lookup(self, unit_type, kind, attr_filter=None):
mask = None
if kind in ('foreign_key', 'reverse_foreign_key') and avail_unit_type == feature.name:
mask = unit_type.name
feature_attrs = feature.attrs
feature_attrs.update({
'unit_type': unit_type,
'mask': mask,
'kind': kind,
})
features.append(
_ResolvedFeature(feature.name, providers=[d.provider for d in instances], unit_type=unit_type, mask=mask, kind=kind)
_ResolvedFeature(providers=[d.provider for d in instances], **feature_attrs)
)
return features

Expand Down
50 changes: 27 additions & 23 deletions mensor/measures/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def __init__(self, name, unit_type=None, via=None, external=False, private=False
if self.ALLOW_ALL_ATTRIBUTES or attr in self.EXTRA_ATTRIBUTES:
setattr(self, attr, value)
else:
raise KeyError("No such attribute {}.".format(attr))
raise AttributeError(
"Cannot initialize {}<{}> with attribute '{}'.".format(self.__class__.__name__, self.name, attr)
)

def __getattr__(self, name):
if name.startswith('_'):
Expand Down Expand Up @@ -446,10 +448,7 @@ def transforms(self):
@transforms.setter
def transforms(self, transforms):
# TODO: Check structure of transforms dict
if not transforms:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nonsubstantive, but I do like oneliners

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aye... I like this better too. I think I left it uncollapsed because I was expecting to do more logic than would be elegant in a one-liner. For now, this works well.

self._transforms = {}
else:
self._transforms = transforms
self._transforms = {} if not transforms else transforms

@property
def as_external(self):
Expand Down Expand Up @@ -692,8 +691,8 @@ def desc(self):

class _Dimension(_ProvidedFeature):

def __init__(self, name, expr=None, default=None, desc=None, shared=False, partition=False, requires_constraint=False, provider=None):
_ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider)
def __init__(self, name, expr=None, default=None, desc=None, shared=False, partition=False, requires_constraint=False, provider=None, **attrs):
_ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider, **attrs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed, perhaps we should be more explicit with which attrs get sent through

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

if not shared and partition:
raise ValueError("Partitions must be shared.")
self.partition = partition
Expand Down Expand Up @@ -777,20 +776,25 @@ def matches(self, unit_type, reverse=False):
class _Measure(_ProvidedFeature):

def __init__(self, name, expr=None, default=None, desc=None,
distribution='normal', shared=False, provider=None):
_ProvidedFeature.__init__(self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider)
distribution='normal', shared=False, provider=None, **attrs):
_ProvidedFeature.__init__(
self, name, expr=expr, default=default, desc=desc, shared=shared, provider=provider,
**attrs
)
self.distribution = distribution

def transforms_for_unit_type(self, unit_type, stats_registry=None):
transforms = {
transforms = { # defaults
'pre_agg': None,
'agg': 'sum',
'post_agg': None,
'pre_rebase_agg': None,
'rebase_agg': 'sum',
'post_rebase_agg': None
}

if isinstance(self.transforms, dict):
transforms.update(self.transforms.get('_default', {}))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default transforms for all unit types?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remain unconvinced, I'm afraid, that this is a good idea, as it leads to confusing (and usually wrong) behaviour :/. Suppose we chose count as the default aggregation for some measure. That might make sense the first time it is aggregated, but on subsequent aggregations (which may occur due to multiple rebase operations, for example), you would be taking the count of counts at each aggregations, which would make the resulting measure likely meaningless.

transforms.update(self.transforms.get(unit_type, {}))

backend_aggs = stats_registry.aggregations.for_provider(self.provider)
Expand Down Expand Up @@ -828,14 +832,15 @@ def get_fields(self, unit_type=None, stats=True, rebase_agg=False, stats_registr
"""
assert stats_registry is not None
assert not (rebase_agg and stats)

if for_pandas:
from mensor.backends.pandas import PandasMeasureProvider
provider = PandasMeasureProvider
else:
provider = self.provider

transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry)
if stats:
transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry)
return OrderedDict([
(
(
Expand All @@ -850,18 +855,17 @@ def get_fields(self, unit_type=None, stats=True, rebase_agg=False, stats_registr
)
for field_name, agg_method in stats_registry.distribution_for_provider(self.distribution, provider).items()
])
else:
transforms = self.transforms_for_unit_type(unit_type, stats_registry=stats_registry)
return OrderedDict([
(
'{fieldname}|raw'.format(fieldname=self.fieldname(role=None, unit_type=unit_type if not rebase_agg else None)),
{
'agg': transforms['rebase_agg'] if rebase_agg else transforms['agg'],
'pre_agg': transforms['pre_rebase_agg'] if rebase_agg else transforms['pre_agg'],
'post_agg': transforms['post_rebase_agg'] if rebase_agg else transforms['post_agg'],
}
)
])

return OrderedDict([
(
'{fieldname}|raw'.format(fieldname=self.fieldname(role=None, unit_type=unit_type if not rebase_agg else None)),
{
'agg': transforms['rebase_agg'] if rebase_agg else transforms['agg'],
'pre_agg': transforms['pre_rebase_agg'] if rebase_agg else transforms['pre_agg'],
'post_agg': transforms['post_rebase_agg'] if rebase_agg else transforms['post_agg'],
}
)
])

@classmethod
def get_all_fields(self, measures, unit_type=None, stats=True, rebase_agg=False, stats_registry=None, for_pandas=False):
Expand Down