diff --git a/graphtools/__init__.py b/graphtools/__init__.py index 8fc8a50..7384afc 100644 --- a/graphtools/__init__.py +++ b/graphtools/__init__.py @@ -1,2 +1,2 @@ -from .api import Graph, from_igraph +from .api import Graph, from_igraph, read_pickle from .version import __version__ diff --git a/graphtools/api.py b/graphtools/api.py index 5d48316..30a6f26 100644 --- a/graphtools/api.py +++ b/graphtools/api.py @@ -2,6 +2,8 @@ import warnings import tasklogger from scipy import sparse +import pickle +import pygsp from . import base from . import graphs @@ -283,3 +285,22 @@ def from_igraph(G, attribute="weight", **kwargs): K = G.get_adjacency(attribute=None).data return Graph(sparse.coo_matrix(K), precomputed='adjacency', **kwargs) + + +def read_pickle(path): + """Load pickled Graphtools object (or any object) from file. + + Parameters + ---------- + path : str + File path where the pickled object will be loaded. + """ + with open(path, 'rb') as f: + G = pickle.load(f) + + if not isinstance(G, base.BaseGraph): + warnings.warn( + 'Returning object that is not a graphtools.base.BaseGraph') + elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str): + G.logger = pygsp.utils.build_logger(G.logger) + return G diff --git a/graphtools/base.py b/graphtools/base.py index c2b2b2a..d0e1fa4 100644 --- a/graphtools/base.py +++ b/graphtools/base.py @@ -10,6 +10,8 @@ import warnings import numbers import tasklogger +import pickle +import sys try: import pandas as pd @@ -635,6 +637,23 @@ def to_igraph(self, attribute="weight", **kwargs): return ig.Graph.Weighted_Adjacency(utils.to_dense(W).tolist(), attr=attribute, **kwargs) + def to_pickle(self, path): + """Save the current Graph to a pickle. + + Parameters + ---------- + path : str + File path where the pickled object will be stored. + """ + if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph): + # python 3.5, 3.6 + logger = self.logger + self.logger = logger.name + with open(path, 'wb') as f: + pickle.dump(self, f) + if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph): + self.logger = logger + class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)): """Interface between BaseGraph and PyGSP. diff --git a/graphtools/graphs.py b/graphtools/graphs.py index 8c27d32..157da60 100644 --- a/graphtools/graphs.py +++ b/graphtools/graphs.py @@ -14,8 +14,6 @@ import tasklogger from .utils import (set_diagonal, - elementwise_minimum, - elementwise_maximum, set_submatrix) from .base import DataGraph, PyGSPGraph diff --git a/test/test_api.py b/test/test_api.py index 5379cf5..e5ae0d8 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -9,6 +9,8 @@ import igraph import numpy as np import graphtools +import tempfile +import os def test_from_igraph(): @@ -81,6 +83,56 @@ def test_to_igraph(): attribute="weight").data) == G.W) +def test_pickle_io_knngraph(): + G = build_graph(data, knn=5, decay=None) + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, 'tmp.pkl') + G.to_pickle(path) + G_prime = graphtools.read_pickle(path) + assert isinstance(G_prime, type(G)) + + +def test_pickle_io_traditionalgraph(): + G = build_graph(data, knn=5, decay=10, thresh=0) + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, 'tmp.pkl') + G.to_pickle(path) + G_prime = graphtools.read_pickle(path) + assert isinstance(G_prime, type(G)) + + +def test_pickle_io_landmarkgraph(): + G = build_graph(data, knn=5, decay=None, + n_landmark=data.shape[0] // 2) + L = G.landmark_op + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, 'tmp.pkl') + G.to_pickle(path) + G_prime = graphtools.read_pickle(path) + assert isinstance(G_prime, type(G)) + np.testing.assert_array_equal(L, G_prime._landmark_op) + + +def test_pickle_io_pygspgraph(): + G = build_graph(data, knn=5, decay=None, use_pygsp=True) + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, 'tmp.pkl') + G.to_pickle(path) + G_prime = graphtools.read_pickle(path) + assert isinstance(G_prime, type(G)) + assert G_prime.logger.name == G.logger.name + + +@warns(UserWarning) +def test_pickle_bad_pickle(): + import pickle + with tempfile.TemporaryDirectory() as tempdir: + path = os.path.join(tempdir, 'tmp.pkl') + with open(path, 'wb') as f: + pickle.dump('hello world', f) + G = graphtools.read_pickle(path) + + @warns(UserWarning) def test_to_pygsp_invalid_precomputed(): G = build_graph(data)