diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index af16f0f1..265ebc15 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -32,11 +32,9 @@ class BaseSystemCollection(MutableSequence): _systems: list List of rod-like objects. - """ - - """ Developer Note ----- + Note ---- We can directly subclass a list for the @@ -174,19 +172,19 @@ def finalize(self): def synchronize(self, time: float): # Collection call _feature_group_synchronize for func in self._feature_group_synchronize: - func(time) + func(time=time) def constrain_values(self, time: float): # Collection call _feature_group_constrain_values for func in self._feature_group_constrain_values: - func(time) + func(time=time) def constrain_rates(self, time: float): # Collection call _feature_group_constrain_rates for func in self._feature_group_constrain_rates: - func(time) + func(time=time) def apply_callbacks(self, time: float, current_step: int): # Collection call _feature_group_callback for func in self._feature_group_callback: - func(time, current_step) + func(time=time, current_step=current_step) diff --git a/elastica/modules/connections.py b/elastica/modules/connections.py index 935b9ed5..2c93b7aa 100644 --- a/elastica/modules/connections.py +++ b/elastica/modules/connections.py @@ -5,7 +5,6 @@ Provides the connections interface to connect entities (rods, rigid bodies) using joints (see `joints.py`). """ -import functools import numpy as np from elastica.joint import FreeJoint @@ -60,16 +59,15 @@ def connect( sys_dofs = [self._systems[idx].n_elems for idx in sys_idx] # Create _Connect object, cache it and return to user - _connector = _Connect(*sys_idx, *sys_dofs) - _connector.set_index(first_connect_idx, second_connect_idx) - self._connections.append(_connector) - self._feature_group_synchronize.append_id(_connector) + _connect = _Connect(*sys_idx, *sys_dofs) + _connect.set_index(first_connect_idx, second_connect_idx) + self._connections.append(_connect) + self._feature_group_synchronize.append_id(_connect) - return _connector + return _connect def _finalize_connections(self): # From stored _Connect objects, instantiate the joints and store it - # dev : the first indices stores the # (first rod index, second_rod_idx, connection_idx_on_first_rod, connection_idx_on_second_rod) # to apply the connections to. @@ -82,8 +80,7 @@ def _finalize_connections(self): # FIXME: lambda t is included because OperatorType takes time as an argument def apply_forces(time): - return functools.partial( - connect_instance.apply_forces, + connect_instance.apply_forces( system_one=self._systems[first_sys_idx], index_one=first_connect_idx, system_two=self._systems[second_sys_idx], @@ -91,8 +88,7 @@ def apply_forces(time): ) def apply_torques(time): - return functools.partial( - connect_instance.apply_torques, + connect_instance.apply_torques( system_one=self._systems[first_sys_idx], index_one=first_connect_idx, system_two=self._systems[second_sys_idx], @@ -103,6 +99,9 @@ def apply_torques(time): connection, [apply_forces, apply_torques] ) + self._connections = [] + del self._connections + # Need to finally solve CPP here, if we are doing things properly # This is to optimize the call tree for better memory accesses # https://brooksandrew.github.io/simpleblog/articles/intro-to-graph-optimization-solving-cpp/ @@ -156,7 +155,7 @@ def __init__( def set_index(self, first_idx, second_idx): # TODO assert range # First check if the types of first rod idx and second rod idx variable are same. - assert type(first_idx) == type( + assert type(first_idx) is type( second_idx ), "Type of first_connect_idx :{}".format( type(first_idx) diff --git a/elastica/modules/contact.py b/elastica/modules/contact.py index 6585e6a8..5e333aac 100644 --- a/elastica/modules/contact.py +++ b/elastica/modules/contact.py @@ -5,7 +5,6 @@ Provides the contact interface to apply contact forces between objects (rods, rigid bodies, surfaces). """ -import functools from elastica.typing import SystemType, AllowedContactType @@ -71,14 +70,16 @@ def _finalize_contact(self) -> None: ) def apply_contact(time): - return functools.partial( - contact_instance.apply_contact, + contact_instance.apply_contact( system_one=self._systems[first_sys_idx], system_two=self._systems[second_sys_idx], ) self._feature_group_synchronize.add_operators(contact, [apply_contact]) + self._contacts = [] + del self._contacts + class _Contact: """ diff --git a/elastica/modules/feature_group.py b/elastica/modules/feature_group.py index 9dd87955..9578f954 100644 --- a/elastica/modules/feature_group.py +++ b/elastica/modules/feature_group.py @@ -1,4 +1,3 @@ -from typing import Callable from elastica.typing import OperatorType from collections.abc import Iterable @@ -11,10 +10,10 @@ def __init__(self): self._operator_collection: list[list[OperatorType]] = [] self._operator_ids: list[int] = [] - def __iter__(self) -> Callable[[...], None]: + def __iter__(self) -> OperatorType: if not self._operator_collection: raise RuntimeError("Feature group is not instantiated.") - operator_chain = itertools.chain(self._operator_collection) + operator_chain = itertools.chain.from_iterable(self._operator_collection) for operator in operator_chain: yield operator @@ -23,5 +22,5 @@ def append_id(self, feature): self._operator_collection.append([]) def add_operators(self, feature, operators: list[OperatorType]): - idx = self._operator_ids.index(feature) + idx = self._operator_ids.index(id(feature)) self._operator_collection[idx].extend(operators) diff --git a/elastica/modules/forcing.py b/elastica/modules/forcing.py index 1553d85b..e6e43f13 100644 --- a/elastica/modules/forcing.py +++ b/elastica/modules/forcing.py @@ -70,6 +70,9 @@ def _finalize_forcing(self): ext_force_torque, [apply_forces, apply_torques] ) + self._ext_forces_torques = [] + del self._ext_forces_torques + class _ExtForceTorque: """ diff --git a/tests/test_modules/test_base_system.py b/tests/test_modules/test_base_system.py index ebb3b7c3..76f63cd4 100644 --- a/tests/test_modules/test_base_system.py +++ b/tests/test_modules/test_base_system.py @@ -197,7 +197,18 @@ def test_forcing(self, load_collection, legal_forces): simulator_class.add_forcing_to(rod).using(legal_forces) simulator_class.finalize() # After finalize check if the created forcing object is instance of the class we have given. - assert isinstance(simulator_class._ext_forces_torques[-1][-1], legal_forces) + assert isinstance( + simulator_class._feature_group_synchronize._operator_collection[-1][ + -1 + ].func.__self__, + legal_forces, + ) + assert isinstance( + simulator_class._feature_group_synchronize._operator_collection[-1][ + -2 + ].func.__self__, + legal_forces, + ) # TODO: this is a dummy test for synchronize find a better way to test them simulator_class.synchronize(time=0) diff --git a/tests/test_modules/test_connections.py b/tests/test_modules/test_connections.py index bdc2715a..334ece41 100644 --- a/tests/test_modules/test_connections.py +++ b/tests/test_modules/test_connections.py @@ -150,7 +150,7 @@ def test_call_without_setting_connect_throws_runtime_error(self, load_connect): connect = load_connect with pytest.raises(RuntimeError) as excinfo: - connect() + connect.instantiate() assert "No connections provided" in str(excinfo.value) def test_call_improper_args_throws(self, load_connect): @@ -173,7 +173,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: - _ = connect() + _ = connect.instantiate() assert ( r"Unable to construct connection class.\nDid you provide all necessary joint properties?" == str(excinfo.value) @@ -327,21 +327,18 @@ def mock_init(self, *args, **kwargs): def test_connect_finalize_correctness(self, load_rod_with_connects): system_collection_with_connections, connect_cls = load_rod_with_connects + connect = system_collection_with_connections._connections[0] + assert connect._connect_cls == connect_cls system_collection_with_connections._finalize_connections() + assert ( + system_collection_with_connections._feature_group_synchronize._operator_ids[ + 0 + ] + == id(connect) + ) - for ( - fidx, - sidx, - fconnect, - sconnect, - connect, - ) in system_collection_with_connections._connections: - assert type(fidx) is int - assert type(sidx) is int - assert fconnect is None - assert sconnect is None - assert type(connect) is connect_cls + assert not hasattr(system_collection_with_connections, "_connections") @pytest.fixture def load_rod_with_connects_and_indices(self, load_system_with_connects): @@ -392,17 +389,17 @@ def test_connect_call_on_systems(self, load_rod_with_connects_and_indices): system_collection_with_connections_and_indices, connect_cls, ) = load_rod_with_connects_and_indices + mock_connections = [ + c for c in system_collection_with_connections_and_indices._connections + ] system_collection_with_connections_and_indices._finalize_connections() - system_collection_with_connections_and_indices._call_connections() - - for ( - fidx, - sidx, - fconnect, - sconnect, - connect, - ) in system_collection_with_connections_and_indices._connections: + system_collection_with_connections_and_indices.synchronize(0) + + for connection in mock_connections: + fidx, sidx, fconnect, sconnect = connection.id() + connect = connection.instantiate() + end_distance_vector = ( system_collection_with_connections_and_indices._systems[ sidx diff --git a/tests/test_modules/test_contact.py b/tests/test_modules/test_contact.py index 3979bebe..82c41f69 100644 --- a/tests/test_modules/test_contact.py +++ b/tests/test_modules/test_contact.py @@ -48,7 +48,7 @@ def test_call_without_setting_contact_throws_runtime_error(self, load_contact): contact = load_contact with pytest.raises(RuntimeError) as excinfo: - contact() + contact.instantiate() assert "No contacts provided to to establish contact between rod-like object id {0} and {1}, but a Contact was intended as per code. Did you forget to call the `using` method?".format( *contact.id() ) == str( @@ -75,7 +75,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: - _ = contact() + _ = contact.instantiate() assert ( r"Unable to construct contact class.\nDid you provide all necessary contact properties?" == str(excinfo.value) @@ -260,13 +260,15 @@ def mock_init(self, *args, **kwargs): def test_contact_finalize_correctness(self, load_rod_with_contacts): system_collection_with_contacts, contact_cls = load_rod_with_contacts + contact = system_collection_with_contacts._contacts[0].instantiate() + fidx, sidx = system_collection_with_contacts._contacts[0].id() system_collection_with_contacts._finalize_contact() - for fidx, sidx, contact in system_collection_with_contacts._contacts: - assert type(fidx) is int - assert type(sidx) is int - assert type(contact) is contact_cls + assert not hasattr(system_collection_with_contacts, "_contacts") + assert type(fidx) is int + assert type(sidx) is int + assert type(contact) is contact_cls @pytest.fixture def load_contact_objects_with_incorrect_order(self, load_system_with_contacts): @@ -339,19 +341,18 @@ def load_system_with_rods_in_contact(self, load_system_with_contacts): return system_collection_with_rods_in_contact def test_contact_call_on_systems(self, load_system_with_rods_in_contact): + from elastica.contact_forces import _calculate_contact_forces_rod_rod system_collection_with_rods_in_contact = load_system_with_rods_in_contact + mock_contacts = [c for c in system_collection_with_rods_in_contact._contacts] system_collection_with_rods_in_contact._finalize_contact() - system_collection_with_rods_in_contact._call_contacts(time=0) + system_collection_with_rods_in_contact.synchronize(time=0) - from elastica.contact_forces import _calculate_contact_forces_rod_rod + for _contact in mock_contacts: + fidx, sidx = _contact.id() + contact = _contact.instantiate() - for ( - fidx, - sidx, - contact, - ) in system_collection_with_rods_in_contact._contacts: system_one = system_collection_with_rods_in_contact._systems[fidx] system_two = system_collection_with_rods_in_contact._systems[sidx] external_forces_system_one = np.zeros_like(system_one.external_forces) diff --git a/tests/test_modules/test_forcing.py b/tests/test_modules/test_forcing.py index 67732767..bd384fc6 100644 --- a/tests/test_modules/test_forcing.py +++ b/tests/test_modules/test_forcing.py @@ -39,7 +39,7 @@ def test_call_without_setting_forcing_throws_runtime_error(self, load_forcing): forcing = load_forcing with pytest.raises(RuntimeError) as excinfo: - forcing(None) # None is the rod/system parameter + forcing.instantiate() # None is the rod/system parameter assert "No forcing" in str(excinfo.value) def test_call_improper_args_throws(self, load_forcing): @@ -62,7 +62,7 @@ def mock_init(self, *args, **kwargs): # Actual test is here, this should not throw with pytest.raises(TypeError) as excinfo: - _ = forcing() + _ = forcing.instantiate() assert "Unable to construct" in str(excinfo.value) @@ -166,7 +166,7 @@ def mock_init(self, *args, **kwargs): return scwf, MockForcing - def test_friction_plane_forcing_class_sorting(self, load_system_with_forcings): + def test_friction_plane_forcing_class(self, load_system_with_forcings): scwf = load_system_with_forcings @@ -196,19 +196,24 @@ def mock_init(self, *args, **kwargs): ) scwf.add_forcing_to(1).using(MockForcing, 2, 42) # index based forcing + # Now check if the Anisotropic friction and the MockForcing are in the list + assert scwf._ext_forces_torques[-1]._forcing_cls == MockForcing + assert scwf._ext_forces_torques[-2]._forcing_cls == AnisotropicFrictionalPlane scwf._finalize_forcing() - - # Now check if the Anisotropic friction is the last forcing class - assert isinstance(scwf._ext_forces_torques[-1][-1], AnisotropicFrictionalPlane) + assert not hasattr(scwf, "_ext_forces_torques") def test_constrain_finalize_correctness(self, load_rod_with_forcings): scwf, forcing_cls = load_rod_with_forcings + forcing_features = [f for f in scwf._ext_forces_torques] scwf._finalize_forcing() + assert not hasattr(scwf, "_ext_forces_torques") - for x, y in scwf._ext_forces_torques: - assert type(x) is int - assert type(y) is forcing_cls + for _forcing in forcing_features: + x = _forcing.id() + y = _forcing.instantiate() + assert isinstance(x, int) + assert isinstance(y, forcing_cls) @pytest.mark.xfail def test_constrain_finalize_sorted(self, load_rod_with_forcings):