-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
52f36ae
commit 5d167f0
Showing
4 changed files
with
252 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .displayTile import DisplayTile | ||
from .world import World | ||
from .screen import Screen |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from ngclearn import Component, Compartment, numpy as jnp, resolver | ||
from ngcsimlib.logger import warn | ||
|
||
|
||
class DisplayTile(Component): | ||
def __init__(self, name, tile_size, highlight_img, goal_img, | ||
wall_brightness=100, wall_thickness=1, **kwargs): | ||
super().__init__(name, **kwargs) | ||
|
||
self.tile_size = tile_size | ||
self.highlight_img = highlight_img | ||
self.wall_thickness = wall_thickness | ||
self.goal_img = goal_img if goal_img is not None else jnp.zeros((tile_size, tile_size), dtype=jnp.uint8) | ||
self.wall_brightness = wall_brightness | ||
|
||
self.display = Compartment(jnp.zeros((tile_size, tile_size), dtype=jnp.uint8)) | ||
|
||
|
||
blank = jnp.zeros((self.tile_size, self.tile_size), dtype=jnp.uint8) | ||
north_wall = blank.at[0:self.wall_thickness, :].set( | ||
self.wall_brightness) | ||
south_wall = blank.at[-self.wall_thickness:, :].set( | ||
self.wall_brightness) | ||
east_wall = blank.at[:, -self.wall_thickness:].set(self.wall_brightness) | ||
west_wall = blank.at[:, 0:self.wall_thickness].set(self.wall_brightness) | ||
|
||
# self.displays = [blank, north_wall, south_wall, east_wall, west_wall, | ||
# self.highlight_img, self.goal_img] | ||
|
||
self.displays = jnp.array([north_wall.reshape(tile_size**2), | ||
east_wall.reshape(tile_size ** 2), | ||
south_wall.reshape(tile_size**2), | ||
west_wall.reshape(tile_size**2), | ||
self.highlight_img.reshape(tile_size**2), | ||
self.goal_img.reshape(tile_size**2)]) | ||
|
||
def update_display(self, layers): | ||
_layers = jnp.matmul(jnp.diag(layers), self.displays) | ||
self.display.set(jnp.max(_layers, axis=0).reshape(self.tile_size, self.tile_size)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from ngclearn import Compartment, Component, numpy as jnp | ||
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path | ||
|
||
|
||
class Screen(Component): | ||
auto_resolve = False | ||
def __init__(self, name, width, height, tile_size, **kwargs): | ||
super().__init__(name, **kwargs) | ||
self.width = width | ||
self.height = height | ||
self.tile_size = tile_size | ||
self._compartments = [] | ||
|
||
self.display = Compartment( | ||
jnp.zeros((tile_size * width, tile_size * height), dtype=jnp.uint8)) | ||
|
||
self.inputs = [] | ||
for y in range(height): | ||
self.inputs.append([]) | ||
for x in range(width): | ||
_c = Compartment(jnp.zeros((tile_size, tile_size), dtype=jnp.uint8)) | ||
self.__dict__[f"{name}_{y}_{x}"] = _c | ||
self.inputs[y].append(_c) | ||
self._compartments.append((f"{name}_{y}_{x}", y, x)) | ||
|
||
@staticmethod | ||
def build_advance_state(component): | ||
compartments = component._compartments | ||
tile_size = component.tile_size | ||
@staticmethod | ||
def _advance(display, **kwargs): | ||
for c, y, x in compartments: | ||
display = display.at[y * tile_size:(y + 1) * tile_size, | ||
x * tile_size:(x + 1) * tile_size].set(kwargs.get(c)) | ||
return display | ||
|
||
return _advance, ["display"], [], [], ["display"] + [c for c, _, _ in compartments] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from ngclearn import Context, numpy as jnp | ||
from ngclearn.evironment.displayTile import DisplayTile | ||
from ngclearn.evironment.screen import Screen | ||
from ngcsimlib.logger import warn | ||
|
||
class World(Context): | ||
def __init__(self, name, world_width, world_height, view_width=None, | ||
view_height=None): | ||
super().__init__(name) | ||
|
||
self.width = world_width | ||
self.height = world_height | ||
self.view_width = view_width if view_width is not None else world_width | ||
self.view_height = view_height if view_height is not None else ( | ||
world_height) | ||
if view_width is None and view_height is None: | ||
self.ego = False | ||
else: | ||
self.ego = True | ||
|
||
self.movement_map = jnp.zeros((self.height, self.width)) | ||
|
||
self._tiles = None | ||
self.screen = None | ||
self._wall_map = None | ||
|
||
self._current_position = (0, 0) | ||
self._goal_position = (0, 0) | ||
|
||
def reset(self, start_loc): | ||
self._current_position = start_loc | ||
self._update_position() | ||
|
||
@property | ||
def current_position(self): | ||
return self._current_position | ||
|
||
@property | ||
def goal_position(self): | ||
return self._goal_position | ||
|
||
def set_movement(self, locs, movable=True): | ||
for y, x in locs: | ||
self.movement_map = self.movement_map.at[y, x].set(1 if movable else 0) | ||
|
||
|
||
def _build_tiles(self, **kwargs): | ||
for y in range(self.view_height): | ||
self._tiles.append([]) | ||
for x in range(self.view_width): | ||
new_tile = DisplayTile(name=f"t_{y}_{x}", **kwargs) | ||
self._tiles[y].append(new_tile) | ||
self.screen.inputs[y][x] << new_tile.display | ||
|
||
def _build_wall_map(self): | ||
wall_map = jnp.zeros((self.height, self.width, 4), jnp.uint8) | ||
for y in range(self.height): | ||
for x in range(self.width): | ||
movable_tile = self.movement_map[y][x] | ||
if movable_tile == 0: | ||
continue | ||
|
||
#North | ||
dy = y - 1 | ||
if dy < 0: | ||
wall_map = wall_map.at[y, x, 0].set(1) | ||
elif self.movement_map[dy][x] == 0: | ||
wall_map = wall_map.at[y, x, 0].set(1) | ||
|
||
#east | ||
dx = x + 1 | ||
if dx >= self.width: | ||
wall_map = wall_map.at[y, x, 1].set(1) | ||
elif self.movement_map[y][dx] == 0: | ||
wall_map = wall_map.at[y, x, 1].set(1) | ||
|
||
#South | ||
dy = y + 1 | ||
if dy >= self.height: | ||
wall_map = wall_map.at[y, x, 2].set(1) | ||
elif self.movement_map[dy][x] == 0: | ||
wall_map = wall_map.at[y, x, 2].set(1) | ||
|
||
#West | ||
dx = x - 1 | ||
if dx < 0: | ||
wall_map = wall_map.at[y, x, 3].set(1) | ||
elif self.movement_map[y][dx] == 0: | ||
wall_map = wall_map.at[y, x, 3].set(1) | ||
self._wall_map = wall_map | ||
|
||
def get_view_radi(self): # North radius, East radius, South Radius, West radius | ||
ev = wv = (self.view_width-1) // 2 | ||
nv = sv = (self.view_height-1) // 2 | ||
return -nv, ev, sv, -wv | ||
|
||
def _update_tile(self, dy, dx, ay, ax): | ||
if 0 <= ay < self.height and 0 <= ax < self.width: | ||
layers = jnp.array([0, 0, 0, 0, 0, 0], dtype=jnp.uint8) | ||
if (ay, ax) == self._current_position: | ||
layers = layers.at[4].set(1) | ||
if (ay, ax) == self._goal_position: | ||
layers = layers.at[5].set(1) | ||
|
||
layers = layers.at[:4].set(self._wall_map[ay, ax, :]) | ||
self._tiles[dy][dx].update_display(layers) | ||
else: | ||
layers = jnp.array([0, 0, 0, 0, 0, 0], dtype=jnp.uint8) | ||
self._tiles[dy][dx].update_display(layers) | ||
|
||
def _update_position(self, old_poc=None): | ||
if self.ego: | ||
|
||
nr, er, sr, wr = self.get_view_radi() | ||
cy, cx = self._current_position | ||
for ty, y in enumerate(range(nr, sr+1)): | ||
for tx, x in enumerate(range(wr, er+1)): | ||
self._update_tile(ty, tx, cy+y, cx+x) | ||
|
||
else: | ||
if old_poc is None: | ||
for y in range(self.height): | ||
for x in range(self.width): | ||
self._update_tile(y, x, y, x) | ||
else: | ||
oy, ox = old_poc | ||
self._update_tile(oy, ox, oy, ox) | ||
ny, nx = self._current_position | ||
self._update_tile(ny, nx, ny, nx) | ||
|
||
def initialize(self, tile_size, | ||
highlight_img, | ||
goal_img, | ||
wall_brightness=100, | ||
wall_thickness=1, | ||
player_location=(0, 0), | ||
goal_location=(0, 0)): | ||
|
||
if self._tiles is not None: | ||
warn(f"{self.name} is already initialized, skipping") | ||
return | ||
|
||
self._current_position = player_location | ||
self._goal_position = goal_location | ||
|
||
self._tiles = [] | ||
self.screen = Screen("screen", height=self.view_height, | ||
width=self.view_width, tile_size=tile_size) | ||
self._build_tiles(tile_size=tile_size, highlight_img=highlight_img, | ||
goal_img=goal_img, wall_brightness=wall_brightness, | ||
wall_thickness=wall_thickness) | ||
self._build_wall_map() | ||
|
||
self._update_position() | ||
|
||
|
||
|
||
def move(self, action): | ||
dy, dx = action | ||
ly, lx = self._current_position | ||
|
||
ny = ly + dy | ||
nx = lx + dx | ||
|
||
if 0 <= ny < self.height and 0 <= nx < self.width: | ||
if self.movement_map[ny][nx] != 0: | ||
self._current_position = (ny, nx) | ||
else: | ||
return | ||
self._update_position((ly, lx)) | ||
|
||
|
||
|