Skip to content

Commit

Permalink
[geom] Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Oct 18, 2024
1 parent effb3d3 commit 478b49a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
3 changes: 1 addition & 2 deletions phi/geom/_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,7 @@ def __repr__(self):

def __getitem__(self, item) -> 'Cuboid':
item = _keep_vector(slicing_dict(self, item))
rotation = self._rotation_matrix[item] if self._rotation_matrix is not None else None
return Cuboid(self._center[item], self._half_size[item], rotation, size_variable=self._size_variable)
return Cuboid(self._center[item], self._half_size[item], math.slice(self._rotation_matrix, item), size_variable=self._size_variable)

@staticmethod
def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry':
Expand Down
5 changes: 1 addition & 4 deletions phi/geom/_geom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,7 @@ def _stack_geometries(geometries: Sequence[Geometry], set_op: str, dim=None) ->
elif len(geometries) == 1:
return geometries[0]
elif set_op == 'union' and all(type(g) == type(geometries[0]) and isinstance(g, PhiTreeNode) for g in geometries):
# ToDo look into using stacked attributes for intersection
attrs = variable_attributes(geometries[0])
values = {a: math.stack([getattr(g, a) for g in geometries], dim) for a in attrs}
return copy_with(geometries[0], **values)
return math.stack(tuple(geometries), dim, simplify=True)
else:
geos = math.layout(geometries, dim)
return GeometryStack(geos, set_op=set_op)
Expand Down
24 changes: 11 additions & 13 deletions phi/geom/_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self,
Args:
center: Sphere center as `Tensor` with `vector` dimension.
The spatial dimension order should be specified in the `vector` dimension via item names.
Can be left empty to specify dimensions via kwargs.
radius: Sphere radius as `float` or `Tensor`
**center_: Specifies center when the `center` argument is not given. Center position by dimension, e.g. `x=0.5, y=0.2`.
"""
Expand All @@ -41,6 +42,15 @@ def __init__(self,
self._radius_variable = radius_variable
assert 'vector' not in self._radius.shape, f"Sphere radius must not vary along vector but got {radius}"

def __all_attrs__(self) -> tuple:
return ('_center', '_radius')

def __variable_attrs__(self) -> tuple:
return ('_center', '_radius') if self._radius_variable else ('_center',)

def __value_attrs__(self) -> tuple:
return ()

@property
def shape(self):
if self._center is None or self._radius is None:
Expand Down Expand Up @@ -109,7 +119,7 @@ def approximate_closest_surface(self, location: Tensor) -> Tuple[Tensor, Tensor,
center_dist = math.vec_length(center_delta)
sgn_dist = center_dist - self_radius
if instance(self):
self_center, self_radius, sgn_dist, center_delta, center_dist = math.at_min((self.center, self.radius, sgn_dist, center_delta, center_dist), key=abs(sgn_dist), dim=instance)
self_center, self_radius, sgn_dist, center_delta, center_dist = math.at_min((self.center, self.radius, sgn_dist, center_delta, center_dist), key=abs(sgn_dist), dim=instance(self))
normal = math.safe_div(center_delta, center_dist)
default_normal = wrap([1] + [0] * (self.spatial_rank-1), self.shape['vector'])
normal = math.where(center_dist == 0, default_normal, normal)
Expand Down Expand Up @@ -146,18 +156,6 @@ def rotated(self, angle):
def scaled(self, factor: Union[float, Tensor]) -> 'Geometry':
return Sphere(self.center, self.radius * factor, radius_variable=self._radius_variable)

def __variable_attrs__(self):
return ('_center', '_radius') if self._radius_variable else ('_center',)

def __value_attrs__(self):
return '_center',

def __value_attrs__(self):
return '_center', '_radius'

def __value_attrs__(self):
return '_center', '_radius'

def __getitem__(self, item):
item = slicing_dict(self, item)
return Sphere(self._center[_keep_vector(item)], self._radius[item], radius_variable=self._radius_variable)
Expand Down

0 comments on commit 478b49a

Please sign in to comment.