Skip to content

Commit

Permalink
Merge pull request #35 from KrishnaswamyLab/feature/to_pkl
Browse files Browse the repository at this point in the history
adding to_pickle and read_pickle
  • Loading branch information
scottgigante authored Mar 2, 2019
2 parents 31ee73e + 6abacc7 commit 40fefb1
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 3 deletions.
2 changes: 1 addition & 1 deletion graphtools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .api import Graph, from_igraph
from .api import Graph, from_igraph, read_pickle
from .version import __version__
21 changes: 21 additions & 0 deletions graphtools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import warnings
import tasklogger
from scipy import sparse
import pickle
import pygsp

from . import base
from . import graphs
Expand Down Expand Up @@ -283,3 +285,22 @@ def from_igraph(G, attribute="weight", **kwargs):
K = G.get_adjacency(attribute=None).data
return Graph(sparse.coo_matrix(K),
precomputed='adjacency', **kwargs)


def read_pickle(path):
"""Load pickled Graphtools object (or any object) from file.
Parameters
----------
path : str
File path where the pickled object will be loaded.
"""
with open(path, 'rb') as f:
G = pickle.load(f)

if not isinstance(G, base.BaseGraph):
warnings.warn(
'Returning object that is not a graphtools.base.BaseGraph')
elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str):
G.logger = pygsp.utils.build_logger(G.logger)
return G
19 changes: 19 additions & 0 deletions graphtools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import warnings
import numbers
import tasklogger
import pickle
import sys

try:
import pandas as pd
Expand Down Expand Up @@ -635,6 +637,23 @@ def to_igraph(self, attribute="weight", **kwargs):
return ig.Graph.Weighted_Adjacency(utils.to_dense(W).tolist(),
attr=attribute, **kwargs)

def to_pickle(self, path):
"""Save the current Graph to a pickle.
Parameters
----------
path : str
File path where the pickled object will be stored.
"""
if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph):
# python 3.5, 3.6
logger = self.logger
self.logger = logger.name
with open(path, 'wb') as f:
pickle.dump(self, f)
if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph):
self.logger = logger


class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
"""Interface between BaseGraph and PyGSP.
Expand Down
2 changes: 0 additions & 2 deletions graphtools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import tasklogger

from .utils import (set_diagonal,
elementwise_minimum,
elementwise_maximum,
set_submatrix)
from .base import DataGraph, PyGSPGraph

Expand Down
52 changes: 52 additions & 0 deletions test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import igraph
import numpy as np
import graphtools
import tempfile
import os


def test_from_igraph():
Expand Down Expand Up @@ -81,6 +83,56 @@ def test_to_igraph():
attribute="weight").data) == G.W)


def test_pickle_io_knngraph():
G = build_graph(data, knn=5, decay=None)
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))


def test_pickle_io_traditionalgraph():
G = build_graph(data, knn=5, decay=10, thresh=0)
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))


def test_pickle_io_landmarkgraph():
G = build_graph(data, knn=5, decay=None,
n_landmark=data.shape[0] // 2)
L = G.landmark_op
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))
np.testing.assert_array_equal(L, G_prime._landmark_op)


def test_pickle_io_pygspgraph():
G = build_graph(data, knn=5, decay=None, use_pygsp=True)
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))
assert G_prime.logger.name == G.logger.name


@warns(UserWarning)
def test_pickle_bad_pickle():
import pickle
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
with open(path, 'wb') as f:
pickle.dump('hello world', f)
G = graphtools.read_pickle(path)


@warns(UserWarning)
def test_to_pygsp_invalid_precomputed():
G = build_graph(data)
Expand Down

0 comments on commit 40fefb1

Please sign in to comment.