generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
Sep 30, 2023
1 parent
9c3c4e3
commit ba0591a
Showing
6 changed files
with
135 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from swarms_torch.particle_swarm import ParticleSwarmOptimization | ||
from swarms_torch.ant_colony_swarm import AntColonyOptimization | ||
from swarms_torch.queen_bee_swarm import QueenBeeGa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import unittest | ||
import torch | ||
|
||
from swarms_torch import AntColonyOptimization # Import your class | ||
|
||
class TestAntColonyOptimization(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.aco = AntColonyOptimization(goal="Hello ACO", num_ants=1000, num_iterations=10) | ||
|
||
def test_initialization(self): | ||
self.assertEqual(self.aco.goal.tolist(), [ord(c) for c in "Hello ACO"]) | ||
self.assertEqual(self.aco.pheromones.size(), torch.Size([1000])) | ||
self.assertEqual(self.aco.pheromones.tolist(), [1.0] * 1000) | ||
|
||
def test_fitness(self): | ||
solution = torch.tensor([ord(c) for c in "Hello ACO"], dtype=torch.float32) | ||
self.assertEqual(self.aco.fitness(solution).item(), 0) # Should be maximum fitness | ||
|
||
def test_update_pheromones(self): | ||
initial_pheromones = self.aco.pheromones.clone() | ||
self.aco.solutions = [torch.tensor([ord(c) for c in "Hello ACO"], dtype=torch.float32) for _ in range(1000)] | ||
self.aco.update_pheromones() | ||
# After updating, pheromones should not remain the same | ||
self.assertFalse(torch.equal(initial_pheromones, self.aco.pheromones)) | ||
|
||
def test_choose_next_path(self): | ||
path = self.aco.choose_next_path() | ||
# Path should be an integer index within the number of ants | ||
self.assertIsInstance(path, int) | ||
self.assertGreaterEqual(path, 0) | ||
self.assertLess(path, 1000) | ||
|
||
def test_optimize(self): | ||
solution = self.aco.optimize() | ||
self.assertIsInstance(solution, str) | ||
# Given enough iterations and ants, the solution should approach the goal. For short runs, this might not hold. | ||
# self.assertEqual(solution, "Hello ACO") | ||
|
||
def test_invalid_parameters(self): | ||
with self.assertRaises(ValueError): | ||
_ = AntColonyOptimization(num_ants=-5) | ||
with self.assertRaises(ValueError): | ||
_ = AntColonyOptimization(evaporation_rate=1.5) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
import torch | ||
|
||
from swarms_torch import ParticleSwarmOptimization # Import your class here | ||
|
||
class TestParticleSwarmOptimization(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.pso = ParticleSwarmOptimization(goal="Hello", n_particles=10) | ||
|
||
def test_initialization(self): | ||
self.assertEqual(self.pso.goal.tolist(), [ord(c) for c in "Hello"]) | ||
self.assertEqual(self.pso.particles.size(), (10, 5)) | ||
self.assertEqual(self.pso.velocities.size(), (10, 5)) | ||
|
||
def test_compute_fitness(self): | ||
particle = torch.tensor([ord(c) for c in "Hello"]) | ||
fitness = self.pso.compute_fitness(particle) | ||
self.assertEqual(fitness.item(), 1.0) | ||
|
||
def test_update(self): | ||
initial_particle = self.pso.particles.clone() | ||
self.pso.update() | ||
# After updating, particles should not remain the same (in most cases) | ||
self.assertFalse(torch.equal(initial_particle, self.pso.particles)) | ||
|
||
def test_optimize(self): | ||
initial_best_particle = self.pso.global_best.clone() | ||
self.pso.optimize(iterations=10) | ||
# After optimization, global best should be closer to the goal | ||
initial_distance = torch.norm((initial_best_particle - self.pso.goal).float()).item() | ||
final_distance = torch.norm((self.pso.global_best - self.pso.goal).float()).item() | ||
self.assertLess(final_distance, initial_distance) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import unittest | ||
import torch | ||
from swarms_torch import QueenBeeGa # Import the class | ||
|
||
class TestQueenBeeGa(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.optimizer = QueenBeeGa(goal="Hello QBGA", pop_size=50) | ||
|
||
def test_initialization(self): | ||
self.assertEqual(self.optimizer.goal, "Hello QBGA") | ||
self.assertEqual(self.optimizer.gene_length, len("Hello QBGA")) | ||
self.assertIsNone(self.optimizer.queen) | ||
self.assertIsNone(self.optimizer.queen_fitness) | ||
|
||
def test_encode_decode(self): | ||
encoded = QueenBeeGa.encode("Hello") | ||
decoded = QueenBeeGa.decode(encoded) | ||
self.assertEqual(decoded, "Hello") | ||
|
||
def test_evolution(self): | ||
initial_population = self.optimizer.pool.clone() | ||
self.optimizer._evolve() | ||
self.assertFalse(torch.equal(initial_population, self.optimizer.pool)) | ||
|
||
def test_run(self): | ||
initial_population = self.optimizer.pool.clone() | ||
self.optimizer.run(max_generations=10) | ||
self.assertNotEqual(QueenBeeGa.decode(self.optimizer.queen), QueenBeeGa.decode(initial_population[0])) | ||
|
||
def test_check_convergence(self): | ||
self.optimizer.pool = torch.stack([self.optimizer.target_gene] * 50) | ||
self.assertTrue(self.optimizer._check_convergence()) | ||
|
||
def test_invalid_parameters(self): | ||
with self.assertRaises(ValueError): | ||
_ = QueenBeeGa(mutation_prob=1.5) | ||
with self.assertRaises(ValueError): | ||
_ = QueenBeeGa(strong_mutation_rate=-0.5) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |