Skip to content

Commit

Permalink
Updates to connect to NDIF via streaming response
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Jan 1, 2024
1 parent 69689f0 commit aa5b5a3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
14 changes: 11 additions & 3 deletions src/nnsight/contexts/Runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import pickle

import requests
import socketio

from .. import CONFIG, pydantics
Expand Down Expand Up @@ -101,7 +104,7 @@ def blocking_request(self, request: pydantics.RequestModel):
f"wss://{CONFIG.API.HOST}",
socketio_path="/ws/socket.io",
transports=["websocket"],
wait_timeout=10
wait_timeout=10,
)

# Called when receiving a response from the server.
Expand All @@ -116,10 +119,15 @@ def blocking_response(data):
# If the status of the response is completed, update the local nodes that the user specified to save.
# Then disconnect and continue.
if response.status == pydantics.ResponseModel.JobStatus.COMPLETED:
for name, value in response.saves.items():
with requests.get(
url=f"https://{CONFIG.API.HOST}/retrieve/{response.id}", stream=True
) as stream:
result_response = pydantics.ResponseModel(**pickle.load(stream.raw))

for name, value in result_response.result.saves.items():
self.graph.nodes[name].value = value

self.output = response.output
self.output = result_response.result.output

sio.disconnect()
# Or if there was some error.
Expand Down
19 changes: 12 additions & 7 deletions src/nnsight/pydantics/Response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
from pydantic import BaseModel, field_validator


class ResultModel(BaseModel):
id: str
output: Any = None
saves: Any = None

@field_validator("output", "saves")
def unpickle(cls, value: bytes):
return pickle.loads(value)


class ResponseModel(BaseModel):
class JobStatus(Enum):
RECEIVED = "RECEIVED"
Expand All @@ -22,11 +32,11 @@ class JobStatus(Enum):
description: str

received: datetime = None
saves: Union[bytes, Any] = None
output: Union[bytes, Any] = None
session_id: str = None
blocking: bool = False

result: Union[bytes, ResultModel] = None

def __str__(self) -> str:
return f"{self.id} - {self.status.name}: {self.description}"

Expand All @@ -37,8 +47,3 @@ def log(self, logger: logging.Logger) -> ResponseModel:
logger.info(str(self))

return self

@field_validator("output", "saves")
@classmethod
def unpickle(cls, value):
return pickle.loads(value)

0 comments on commit aa5b5a3

Please sign in to comment.