Adds a first draft of a generic IK functionality #4

Open · wants to merge 15 commits into base: revisions
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
config.yaml
saved_models/
datasets/
**/results/
122 changes: 122 additions & 0 deletions generative_graphik/utils/api.py
@@ -0,0 +1,122 @@
import itertools
from typing import Callable, Optional

from liegroups.numpy.se3 import SE3Matrix
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader

from graphik.graphs import ProblemGraphRevolute
from graphik.robots import RobotRevolute
from graphik.utils import graph_from_pos
from generative_graphik.utils.dataset_generation import generate_data_point_from_pose, create_dataset_from_data_points
from generative_graphik.utils.get_model import get_model
from generative_graphik.utils.torch_to_graphik import joint_transforms_to_t_zero


def _default_cost_function(T_desired: torch.Tensor, T_eef: torch.Tensor) -> torch.Tensor:
"""
The default cost function for the inverse kinematics problem: the sum of squared element-wise differences between
the desired and actual homogeneous end-effector poses (the squared Frobenius norm of T_desired - T_eef).

:param T_desired: The desired end-effector pose.
:param T_eef: The actual end-effector pose.
:return: The cost.
"""
return torch.sum((T_desired - T_eef) ** 2)


def _get_goal_idx(num_robots, samples_per_robot, batch_size, num_batch, idx_batch):
num_sample = num_batch * batch_size + idx_batch
return num_sample % samples_per_robot

def _get_robot_idx(num_robots, samples_per_robot, batch_size, num_batch, idx_batch):
num_sample = num_batch * batch_size + idx_batch
return num_sample // samples_per_robot


def ik(kinematic_chains: torch.tensor,
goals: torch.tensor,
samples: int = 16,
return_all: bool = False,
ik_cost_function: Callable = _default_cost_function,
batch_size: int = 64,
) -> torch.Tensor:
"""
This function takes robot kinematics and any number of goal poses and solves the inverse kinematics using graphIK.

:param kinematic_chains: A tensor of shape (nR, nJ, 4, 4) containing the consecutive joint transformations of nR
robots with nJ joints each.
:param goals: A tensor of shape (nR, nG, 4, 4) containing the desired end-effector poses.
:param samples: The number of samples to use for the forward pass of the model.
:param return_all: If True, returns all the samples from the forward pass, so the resulting tensor has shape
nR x nG x samples x nJ. If False, returns only the best sample per goal, so the resulting tensor has shape nR x nG x nJ.
:param ik_cost_function: The cost function used to select the best sample when return_all is False.
:param batch_size: The number of IK problems processed per batch; the effective forward batch is batch_size * samples.
:return: Joint angles; see return_all for the shape.
"""
device = kinematic_chains.device
model = get_model().to(device)

assert len(kinematic_chains.shape) == 4, f'Expected 4D tensor, got {kinematic_chains.shape}'
nr, nj, _, _ = kinematic_chains.shape
_, nG, _, _ = goals.shape
eef = f'p{nj}'

t_zeros = {i: joint_transforms_to_t_zero(kinematic_chains[i], [f'p{j}' for j in range(1 + nj)], se3type='numpy') for
i in range(nr)}
robots = {i: RobotRevolute({'num_joints': nj, 'T_zero': t_zeros[i]}) for i in range(nr)}
graphs = {i: ProblemGraphRevolute(robots[i]) for i in range(nr)}
if return_all:
q = torch.zeros((nr, nG, samples, nj), device=device)
else:
q = torch.zeros((nr, nG, nj), device=device)

problems = list()
for i, j in itertools.product(range(nr), range(nG)):
graph = graphs[i]
goal = goals[i, j]
problems.append(generate_data_point_from_pose(graph, goal))

# FIXME: Create one data point per sample until forward_eval works correctly with more than one sample
problems_times_samples = list(itertools.chain.from_iterable(zip(*[problems] * samples)))
data = create_dataset_from_data_points(problems_times_samples)
batch_size_forward = batch_size * samples
loader = DataLoader(data, batch_size=batch_size_forward, shuffle=False, num_workers=0)

for i, problem in enumerate(loader):
problem = model.preprocess(problem)
b = len(problem) # The actual batch size (might be smaller than batch_size_forward at the end of the dataset)
num_nodes_per_graph = int(problem.num_nodes / b)
P_all_ = model.forward_eval(
x=problem.pos,
h=torch.cat((problem.type, problem.goal_data_repeated_per_node), dim=-1),
edge_attr=problem.edge_attr,
edge_attr_partial=problem.edge_attr_partial,
edge_index=problem.edge_index_full,
partial_goal_mask=problem.partial_goal_mask,
nodes_per_single_graph=num_nodes_per_graph,
batch_size=b,
num_samples=1
).squeeze()
# Rearrange, s.t. we have problem_nr x sample_nr x node_nr x 3
P_all = P_all_.view(b // samples, samples, num_nodes_per_graph, 3)

for idx in range(b // samples):
idx_robot = _get_robot_idx(nr, nG, batch_size, i, idx)
idx_goal = _get_goal_idx(nr, nG, batch_size, i, idx)
graph = graphs[idx_robot]
goal = goals[idx_robot, idx_goal]
goalse3 = SE3Matrix.from_matrix(goal.detach().cpu().numpy(), normalize=True)
best = float('inf')
for sample in range(samples):
P = P_all[idx, sample, ...]
q_s = graph.joint_variables(graph_from_pos(P.detach().cpu().numpy(), graph.node_ids), {eef: goalse3})
if return_all:
q[idx_robot, idx_goal, sample] = torch.tensor([q_s[key] for key in robots[idx_robot].joint_ids[1:]], device=device)
T_ee = robots[idx_robot].pose(q_s, eef)
cost = ik_cost_function(goal, torch.tensor(T_ee.as_matrix()).to(goal))
if cost < best:
best = cost
q[idx_robot, idx_goal] = torch.tensor([q_s[key] for key in robots[idx_robot].joint_ids[1:]], device=device)
return q
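
Not part of the diff — a minimal usage sketch of the proposed ik() entry point, for review context. The link offsets and goal poses below are placeholders standing in for a real kinematic model, it assumes a pretrained model is reachable through config.yaml as get_model() expects, and the weighted cost is a hypothetical illustration of the ik_cost_function hook.

import torch
from generative_graphik.utils.api import ik

nR, nJ, nG = 2, 6, 4                           # two 6-joint robots, four goals each
chain = torch.eye(4).repeat(nJ, 1, 1)          # consecutive joint transforms, (nJ, 4, 4)
chain[:, 2, 3] = 0.3                           # placeholder 0.3 m offset between consecutive joints
kinematic_chains = chain.repeat(nR, 1, 1, 1)   # (nR, nJ, 4, 4)
goals = torch.eye(4).repeat(nR, nG, 1, 1)      # desired end-effector poses, (nR, nG, 4, 4)

# Hypothetical cost weighting position error more heavily than orientation error.
def weighted_cost(T_desired: torch.Tensor, T_eef: torch.Tensor) -> torch.Tensor:
    pos_err = torch.sum((T_desired[:3, 3] - T_eef[:3, 3]) ** 2)
    rot_err = torch.sum((T_desired[:3, :3] - T_eef[:3, :3]) ** 2)
    return 10.0 * pos_err + rot_err

q_best = ik(kinematic_chains, goals, samples=16, ik_cost_function=weighted_cost)  # (nR, nG, nJ)
q_all = ik(kinematic_chains, goals, samples=16, return_all=True)                  # (nR, nG, 16, nJ)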
101 changes: 88 additions & 13 deletions generative_graphik/utils/dataset_generation.py
@@ -1,8 +1,11 @@
from typing import List, Union
from typing import Iterable, List, Union

from liegroups.numpy.se2 import SE2Matrix
from liegroups.numpy.se3 import SE3Matrix
import numpy as np
import os
from tqdm import tqdm
from dataclasses import dataclass
from dataclasses import dataclass, fields

import torch
from torch_geometric.data import InMemoryDataset, Data
@@ -57,31 +60,94 @@ class StructData:
edge_index_full: Union[List[torch.Tensor], torch.Tensor]
T0: Union[List[torch.Tensor], torch.Tensor]

def generate_data_point_from_pose(graph, T_ee):
struct_data = generate_struct_data(graph)

def create_dataset_from_data_points(data_points: Iterable[Data]) -> CachedDataset:
"""Takes an iterable of Data objects and returns a CachedDataset by concatenating them."""
data = tuple(data_points)
types = torch.cat([d.type for d in data], dim=0)
T0 = torch.cat([d.T0 for d in data], dim=0).reshape(-1, 4, 4)
device = T0.device
num_joints = torch.concat([d.num_joints for d in data])
num_nodes = torch.tensor([d.num_nodes for d in data], device=device)
num_edges = torch.tensor([d.num_edges for d in data], device=device)

P = torch.cat([d.pos for d in data], dim=0)
distances = torch.cat([d.edge_attr for d in data], dim=0)
T_ee = torch.stack([d.T_ee for d in data], dim=0)
masks = torch.cat([d.partial_mask for d in data], dim=-1)
edge_index_full = torch.cat([d.edge_index_full for d in data], dim=-1)
partial_goal_mask = torch.cat([d.partial_goal_mask for d in data], dim=-1)

node_slice = torch.cat([torch.tensor([0], device=device), (num_nodes).cumsum(dim=-1)])
joint_slice = torch.cat([torch.tensor([0], device=device), (num_joints).cumsum(dim=-1)])
frame_slice = torch.cat([torch.tensor([0], device=device), (num_joints + 1).cumsum(dim=-1)])
robot_slice = torch.arange(num_joints.size(0) + 1, device=device)
edge_full_slice = torch.cat([torch.tensor([0], device=device), (num_edges).cumsum(dim=-1)])

slices = {
"edge_attr": edge_full_slice,
"pos": node_slice,
"type": node_slice,
"T_ee": robot_slice,
"num_joints": robot_slice,
"partial_mask": edge_full_slice,
"partial_goal_mask": node_slice,
"edge_index_full": edge_full_slice,
"M": frame_slice,
"q_goal": joint_slice,
}

data = Data(
type=types,
pos=P,
edge_attr=distances,
T_ee=T_ee,
num_joints=num_joints.type(torch.int32),
partial_mask=masks,
partial_goal_mask=partial_goal_mask,
edge_index_full=edge_index_full.type(torch.int32),
M=T0,
)

return CachedDataset(data, slices)

def generate_data_point_from_pose(graph, T_ee, device=None) -> Data:
"""
Generates a data point (~problem) from a problem graph and a desired end-effector pose.
"""
if isinstance(T_ee, torch.Tensor):
if device is None:
device = T_ee.device
T_ee = T_ee.detach().cpu().numpy()
if isinstance(T_ee, np.ndarray):
if T_ee.shape == (4, 4):
T_ee = SE3Matrix.from_matrix(T_ee, normalize=True)
else:
raise ValueError(f"Expected T_ee to be of shape (4, 4) or be SEMatrix, got {T_ee.shape}")
struct_data = generate_struct_data(graph, device)
num_joints = torch.tensor([struct_data.num_joints])
edge_index_full = struct_data.edge_index_full
edge_index_full = struct_data.edge_index_full.to(dtype=torch.int32, device=device)
T0 = struct_data.T0

# Build partial graph nodes
G_partial = graph.from_pose(T_ee)
T_ee = torch.from_numpy(T_ee.as_matrix()).type(torch.float32)
T_ee = torch.from_numpy(T_ee.as_matrix()).to(dtype=torch.float32, device=device)
P = np.array([p[1] for p in list(G_partial.nodes.data('pos', default=np.array([0,0,0])))])
P = torch.from_numpy(P).type(torch.float32)
P = torch.from_numpy(P).to(dtype=torch.float32, device=device)

# Build distances of partial graph
distances = np.sqrt(distance_matrix_from_graph(G_partial))
# Remove self-loop
distances = distances[~np.eye(distances.shape[0],dtype=bool)].reshape(distances.shape[0],-1)
distances = torch.from_numpy(distances).type(torch.float32)
distances = torch.from_numpy(distances).to(dtype=torch.float32, device=device)
# Remove filler NetworkX extra 1s
distances = struct_data.partial_mask * distances.reshape(-1)
return Data(
pos=P,
edge_index_full=edge_index_full.type(torch.int32),
edge_index_full=edge_index_full,
edge_attr=distances.unsqueeze(1),
T_ee=T_ee,
num_joints=num_joints.type(torch.int32),
num_joints=num_joints.to(dtype=torch.int32, device=device),
q_goal=None,
partial_mask=struct_data.partial_mask,
partial_goal_mask=struct_data.partial_goal_mask,
@@ -118,7 +184,7 @@ def generate_data_point(graph):
)


def generate_struct_data(graph):
def generate_struct_data(graph, device=None):

robot = graph.robot
dof = robot.n
@@ -153,7 +219,7 @@ def generate_struct_data(graph):
mask_gen[edge_index_full[0], edge_index_full[1]] > 0
) # get full elements from matrix (same order as generated)

return StructData(
data = StructData(
type=type,
num_joints=num_joints,
num_edges=num_edges,
@@ -163,6 +229,15 @@ def generate_struct_data(graph):
edge_index_full=edge_index_full,
T0=T0,
)
if device is None:
return data
data = StructData(**{
f.name: getattr(data, f.name).to(device)
if isinstance(getattr(data, f.name), torch.Tensor)
else getattr(data, f.name)
for f in fields(data)
})
return data


def generate_specific_robot_data(robots, num_examples, params):
@@ -479,4 +554,4 @@ def main(args):

if __name__ == "__main__":
args = parse_data_generation_args()
main(args)
main(args)
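
Not part of the diff — a sketch of how the new dataset helpers compose, for review context. It mirrors the pattern used in api.py; the identity-based transforms and goal poses are placeholders for a real robot model.

import torch
from graphik.graphs import ProblemGraphRevolute
from graphik.robots import RobotRevolute
from generative_graphik.utils.torch_to_graphik import joint_transforms_to_t_zero
from generative_graphik.utils.dataset_generation import (
    generate_data_point_from_pose,
    create_dataset_from_data_points,
)

nJ = 6
transforms = torch.eye(4).repeat(nJ, 1, 1)
transforms[:, 2, 3] = 0.3  # placeholder link offsets; real values come from the robot model
t_zero = joint_transforms_to_t_zero(transforms, [f'p{j}' for j in range(nJ + 1)], se3type='numpy')

graph = ProblemGraphRevolute(RobotRevolute({'num_joints': nJ, 'T_zero': t_zero}))
goal_poses = [torch.eye(4) for _ in range(8)]  # placeholder 4x4 end-effector goals
points = [generate_data_point_from_pose(graph, T_ee) for T_ee in goal_poses]
dataset = create_dataset_from_data_points(points)  # one CachedDataset covering all goals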
38 changes: 38 additions & 0 deletions generative_graphik/utils/get_model.py
@@ -0,0 +1,38 @@
from argparse import Namespace
import json
from pathlib import Path
from typing import Dict

import torch
import yaml

from generative_graphik.model import Model

_model = None # Use get_model to access the model
PROJECT_DIR = Path(__file__).resolve().parents[2]
CONFIG_DIR = PROJECT_DIR.joinpath('config.yaml')


def get_config() -> Dict:
"""Loads the configuration file"""
with CONFIG_DIR.open('r') as f:
return yaml.safe_load(f)


def get_model() -> Model:
"""Loads the model specified in the configuration file or returns the cached model."""
global _model
if _model is not None:
return _model
config = get_config()
d = Path(config['model'])
if torch.cuda.is_available():
state_dict = torch.load(d.joinpath('net.pth'), map_location='cuda')
else:
state_dict = torch.load(d.joinpath('net.pth'), map_location='cpu')
with d.joinpath('hyperparameters.txt').open('r') as f:
args = Namespace(**json.load(f))
model = Model(args)
model.load_state_dict(state_dict)
_model = model
return model
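
Not part of the diff — a usage sketch of the loader and its cache, assuming a config.yaml at the repository root whose model entry points at a directory containing net.pth and hyperparameters.txt (as the code above expects).

from generative_graphik.utils.get_model import get_config, get_model

config = get_config()        # e.g. {'model': '/path/to/pretrained_model_directory'}
model = get_model()          # loads net.pth (CUDA if available, else CPU) and caches the Model
model_again = get_model()    # a second call returns the cached instance without reloading
assert model is model_again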
51 changes: 51 additions & 0 deletions generative_graphik/utils/torch_to_graphik.py
@@ -0,0 +1,51 @@
from typing import Dict, Sequence, Union

from liegroups.numpy.se3 import SE3Matrix
from liegroups.torch.se3 import SE3Matrix as SE3MatrixTorch
import torch
from torch_geometric.data import Data

from generative_graphik.utils.dataset_generation import StructData


def define_ik_data(robot_data: StructData, goals: torch.Tensor) -> Data:
"""
This function takes a robot and a set of goals and returns a data point for every goal.

:param robot_data: A StructData object containing the robot's kinematics.
:param goals: A tensor of shape (nG, 4, 4) containing the desired end-effector poses.
"""
raise NotImplementedError  # stub: not yet implemented


def joint_transforms_from_t_zeros(T_zero: Dict[str, SE3Matrix], keys: Sequence[str], device: str = None) -> torch.Tensor:
"""Assumes that joints are alphabetically sorted"""
ret = torch.zeros((len(T_zero) - 1, 4, 4), device=device)
for i in range(1, len(keys)):
ret[i - 1] = torch.tensor(T_zero[keys[i-1]].inv().dot(T_zero[keys[i]]).as_matrix(), device=device)
return ret


def joint_transforms_to_t_zero(transforms: torch.Tensor,
keys: Sequence[str],
se3type: str = 'numpy') -> Dict[str, Union[SE3Matrix, SE3MatrixTorch]]:
"""
This function takes a tensor of consecutive joint transformations and returns a dictionary of T_zero transforms,
which describe each joint's pose in the world frame for the zero configuration.

:param transforms: A tensor of shape (nJ, 4, 4).
:param keys: The keys to use for the joint names. Assumes the first key is for the world frame, thus it will be
set to the identity.
:param se3type: The type of SE3 matrix to use. Either 'numpy' or 'torch'.
"""
nj = transforms.shape[0]
t_zero = transforms.clone()
for i in range(1, nj):
t_zero[i] = t_zero[i - 1] @ t_zero[i]
if se3type == 'torch':
t_zero = {keys[i+1]: SE3MatrixTorch.from_matrix(t_zero[i], normalize=True) for i in range(nj)}
t_zero[keys[0]] = SE3MatrixTorch.identity()
else:
t_zero = {keys[i+1]: SE3Matrix.from_matrix(t_zero[i].detach().cpu().numpy(), normalize=True) for i in range(nj)}
t_zero[keys[0]] = SE3Matrix.identity()
return t_zero
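
Not part of the diff — a small round-trip sketch relating the two conversion helpers above; the offsets are placeholder values.

import torch
from generative_graphik.utils.torch_to_graphik import (
    joint_transforms_from_t_zeros,
    joint_transforms_to_t_zero,
)

nJ = 3
keys = [f'p{j}' for j in range(nJ + 1)]                 # 'p0' is the world frame
transforms = torch.eye(4).repeat(nJ, 1, 1)
transforms[:, 0, 3] = torch.tensor([0.1, 0.2, 0.3])     # placeholder x-offsets per joint

t_zero = joint_transforms_to_t_zero(transforms, keys, se3type='numpy')  # dict: 'p0'..'p3' -> SE3Matrix
recovered = joint_transforms_from_t_zeros(t_zero, keys)                 # back to consecutive transforms
assert torch.allclose(recovered.to(transforms.dtype), transforms, atol=1e-6)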
1 change: 1 addition & 0 deletions sample_config.yaml
@@ -0,0 +1 @@
model: '<path_to_your_pretrained_model_directory>'
6 changes: 6 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,6 @@
def main():
pass


if __name__ == '__main__':
main()