diff --git a/rplugin/python3/magma/jupyter_server_api.py b/rplugin/python3/magma/jupyter_server_api.py new file mode 100644 index 0000000..bb5b9de --- /dev/null +++ b/rplugin/python3/magma/jupyter_server_api.py @@ -0,0 +1,122 @@ +import json +import uuid +import re +import time +from queue import Empty as EmptyQueueException +from typing import Any, Dict +from threading import Thread +from queue import Queue +from urllib.parse import urlparse + +import requests +import websocket + +from magma.runtime_state import RuntimeState + + +class JupyterAPIClient: + def __init__(self, + url: str, + kernel_info: Dict[str, Any], + headers: Dict[str, str]): + self._base_url = url + self._kernel_info = kernel_info + self._headers = headers + + self._recv_queue: Queue[Dict[str, Any]] = Queue() + + def wait_for_ready(self, **kwargs): + while True: + response = requests.get(self._kernel_api_base, + headers=self._headers) + response = json.loads(response.text) + if response["execution_state"] in ("idle", "starting"): + return + time.sleep(0.1) + + + def start_channels(self) -> None: + parsed_url = urlparse(self._base_url) + self._socket = websocket.create_connection(f"ws://{parsed_url.hostname}:{parsed_url.port}" + f"/api/kernels/{self._kernel_info['id']}/channels", + header=self._headers, + ) + self._kernel_api_base = f"{self._base_url}/api/kernels/{self._kernel_info['id']}" + + self._iopub_recv_thread = Thread(target=self._recv_message) + self._iopub_recv_thread.start() + + def _recv_message(self) -> None: + while True: + response = json.loads(self._socket.recv()) + self._recv_queue.put(response) + + def get_iopub_msg(self, **kwargs): + if self._recv_queue.empty(): + raise EmptyQueueException + + response = self._recv_queue.get() + + return response + + def execute(self, code: str): + header = { + 'msg_type': 'execute_request', + 'msg_id': uuid.uuid1().hex, + 'session': uuid.uuid1().hex + } + + message = json.dumps({ + 'header': header, + 'parent_header': header, + 'metadata': {}, + 'content': { + 'code': code, + 'silent': False + } + }) + self._socket.send(message) + + def shutdown(self): + requests.delete(self._kernel_api_base, + headers=self._headers) + self._socket.close() + + +class JupyterAPIManager: + def __init__(self, + url: str, + ): + parsed_url = urlparse(url) + self._base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + token_part = re.search(r"token=(.*)", parsed_url.query) + + if token_part: + token = token_part.groups()[0] + self._headers = {'Authorization': 'token ' + token} + else: + # Run notebook with --NotebookApp.disable_check_xsrf="True". + self._headers = {} + + def start_kernel(self) -> None: + url = f"{self._base_url}/api/kernels" + response = requests.post(url, + headers=self._headers) + self._kernel_info = json.loads(response.text) + assert "id" in self._kernel_info, "Could not connect to Jupyter Server API. The URL specified may be incorrect." + self._kernel_api_base = f"{url}/{self._kernel_info['id']}" + + def client(self) -> JupyterAPIClient: + return JupyterAPIClient(url=self._base_url, + kernel_info=self._kernel_info, + headers=self._headers) + + def interrupt_kernel(self) -> None: + requests.post(f"{self._kernel_api_base}/interrupt", + headers=self._headers) + + def restart_kernel(self) -> None: + self.state = RuntimeState.STARTING + requests.post(f"{self._kernel_api_base}/restart", + headers=self._headers) diff --git a/rplugin/python3/magma/runtime.py b/rplugin/python3/magma/runtime.py index 7a994e6..c5b1c9b 100644 --- a/rplugin/python3/magma/runtime.py +++ b/rplugin/python3/magma/runtime.py @@ -1,5 +1,4 @@ -from typing import Optional, Tuple, List, Dict, Generator, IO, Any -from enum import Enum +from typing import Optional, Tuple, List, Dict, Generator, IO, Any, Union from contextlib import contextmanager from queue import Empty as EmptyQueueException import os @@ -8,6 +7,7 @@ import jupyter_client +from magma.runtime_state import RuntimeState from magma.options import MagmaOptions from magma.outputchunks import ( Output, @@ -18,20 +18,15 @@ to_outputchunk, clean_up_text ) - - -class RuntimeState(Enum): - STARTING = 0 - IDLE = 1 - RUNNING = 2 +from magma.jupyter_server_api import JupyterAPIClient, JupyterAPIManager class JupyterRuntime: state: RuntimeState kernel_name: str - kernel_manager: jupyter_client.KernelManager - kernel_client: jupyter_client.KernelClient + kernel_manager: Union[jupyter_client.KernelManager, JupyterAPIManager] + kernel_client: Union[jupyter_client.KernelClient, JupyterAPIClient] allocated_files: List[str] @@ -41,7 +36,18 @@ def __init__(self, kernel_name: str, options: MagmaOptions): self.state = RuntimeState.STARTING self.kernel_name = kernel_name - if ".json" not in self.kernel_name: + if kernel_name.startswith("http://") or kernel_name.startswith("https://"): + self.external_kernel = True + self.kernel_manager = JupyterAPIManager(kernel_name) + self.kernel_manager.start_kernel() + self.kernel_client = self.kernel_manager.client() + self.kernel_client.start_channels() + + self.allocated_files = [] + + self.options = options + + elif ".json" not in self.kernel_name: self.external_kernel = True self.kernel_manager = jupyter_client.manager.KernelManager( @@ -202,8 +208,8 @@ def tick(self, output: Optional[Output]) -> bool: assert isinstance( self.kernel_client, jupyter_client.blocking.client.BlockingKernelClient, - ) - + ) or isinstance( + self.kernel_client, JupyterAPIClient) if not self.is_ready(): try: self.kernel_client.wait_for_ready(timeout=0) diff --git a/rplugin/python3/magma/runtime_state.py b/rplugin/python3/magma/runtime_state.py new file mode 100644 index 0000000..123041d --- /dev/null +++ b/rplugin/python3/magma/runtime_state.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class RuntimeState(Enum): + STARTING = 0 + IDLE = 1 + RUNNING = 2