Source code for scitex_ssh

#!/usr/bin/env python3
"""SciTeX SSH - SSH primitives (exec/copy/attach) and gated reverse tunnels."""

from __future__ import annotations

import os
import socket
import subprocess

from ._allowlist import PolicyError
from ._allowlist import is_allowed as _is_allowed
from ._allowlist import require as _require_allowed
from ._primitives import SSHResult, attach, copy_from, copy_to, exec_remote

__all__ = [
    "__version__",
    "PolicyError",
    "SSHResult",
    "attach",
    "copy_from",
    "copy_to",
    "exec_remote",
    "setup",
    "remove",
    "status",
    "get_version",
]

from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
from importlib.metadata import version as _version

# Resolve the package version: prefer installed distribution metadata and
# fall back to scraping pyproject.toml when the distribution is not installed.
try:
    __version__ = _version("scitex-ssh")
except _PackageNotFoundError:
    from pathlib import Path as _Path

    # pyproject.toml is looked up three ``.parent`` hops above this file.
    _pyproject = _Path(__file__).parent.parent.parent / "pyproject.toml"
    __version__ = "0.0.0+local"
    if _pyproject.exists():
        with open(_pyproject, encoding="utf-8") as _f:
            for _line in _f:
                _key, _sep, _value = _line.partition("=")
                # Match the key exactly so ``version_file = ...`` and similar
                # keys are not mistaken for the version entry.
                if not _sep or _key.strip() != "version":
                    continue
                _value = _value.strip()
                # Accept either quote style and ignore anything after the
                # closing quote (e.g. a trailing comment). An unparsable
                # value keeps the default instead of raising at import time.
                if _value and _value[0] in "\"'":
                    _end = _value.find(_value[0], 1)
                    if _end > 1:
                        __version__ = _value[1:_end]
                break

# Import-time flag; unconditionally True in this module.
# NOTE(review): presumably read by callers probing for this package - confirm.
AVAILABLE = True

# Bundled bash scripts live in a "scripts" directory next to this file.
_PACKAGE_DIR = os.path.dirname(__file__)
_SCRIPTS_DIR = os.path.join(_PACKAGE_DIR, "scripts")


def _run_script(
    script_name: str,
    args: list[str] | None = None,
    *,
    runner=None,
    scripts_dir: str | None = None,
) -> subprocess.CompletedProcess:
    """Invoke one of the bundled bash scripts and return the runner's result.

    The command is ``bash <dir>/<script_name> [args...]`` and is handed to
    the runner with ``capture_output=True, text=True``.

    Parameters
    ----------
    script_name : str
        File name of the script inside the scripts directory.
    args : list of str, optional
        Extra command-line arguments appended after the script path.
    runner : callable, optional
        Callable with the same calling shape as ``subprocess.run``; used
        instead of ``subprocess.run`` when given (e.g. a fake in tests).
    scripts_dir : str, optional
        Directory to resolve ``script_name`` against. When omitted, the
        package-level ``_SCRIPTS_DIR`` is used.

    Returns
    -------
    subprocess.CompletedProcess
        Whatever the runner returns for the call.
    """
    invoke = subprocess.run if runner is None else runner
    root = _SCRIPTS_DIR if scripts_dir is None else scripts_dir
    extra = args if args else []
    command = ["bash", os.path.join(root, script_name)] + extra
    return invoke(command, capture_output=True, text=True)


def _local_host() -> str:
    """Return the local hostname truncated at the first dot."""
    short_name, _, _ = socket.gethostname().partition(".")
    return short_name


def setup(
    port: int,
    bastion_server: str | None = None,
    secret_key_path: str | None = None,
    *,
    host: str | None = None,
    runner=None,
    scripts_dir: str | None = None,
) -> dict:
    """Set up a persistent SSH reverse tunnel.

    Runs the bundled ``setup-autossh-service.sh`` script with
    ``-p <port> -b <bastion_server> -s <secret_key_path>``.

    Parameters
    ----------
    port : int
        The remote port to forward (e.g. 2222).
    bastion_server : str, optional
        The bastion/relay server hostname or IP. When empty or omitted,
        the ``SCITEX_SSH_BASTION_SERVER`` environment variable is used.
    secret_key_path : str, optional
        Path to the SSH private key for authentication. When empty or
        omitted, the ``SCITEX_SSH_SECRET_KEY_PATH`` environment variable
        is used.
    host : str, optional
        Host label passed to the allowlist check for the ``"tunnels"``
        capability. Defaults to the short local hostname.
    runner : callable, optional
        Forwarded to ``_run_script``; a ``subprocess.run``-shaped callable.
    scripts_dir : str, optional
        Forwarded to ``_run_script``; overrides the scripts directory.

    Returns
    -------
    dict
        ``'success'`` (True when the script's return code is 0),
        ``'stdout'`` and ``'stderr'``.

    Raises
    ------
    ValueError
        If bastion_server or secret_key_path is not provided and the
        corresponding environment variable is not set.

    Notes
    -----
    Argument validation happens before the allowlist check, which in turn
    happens before the script is run. The allowlist check presumably
    raises ``PolicyError`` for a host that is not permitted - confirm
    against ``_allowlist.require``.
    """
    if not bastion_server:
        bastion_server = os.environ.get("SCITEX_SSH_BASTION_SERVER")
    if not secret_key_path:
        secret_key_path = os.environ.get("SCITEX_SSH_SECRET_KEY_PATH")

    if not bastion_server:
        raise ValueError(
            "bastion_server is required. Provide it as an argument or set "
            "SCITEX_SSH_BASTION_SERVER environment variable."
        )
    if not secret_key_path:
        raise ValueError(
            "secret_key_path is required. Provide it as an argument or set "
            "SCITEX_SSH_SECRET_KEY_PATH environment variable."
        )

    gate_host = host if host else _local_host()
    _require_allowed(gate_host, "tunnels")

    script_args = ["-p", str(port), "-b", bastion_server, "-s", secret_key_path]
    completed = _run_script(
        "setup-autossh-service.sh",
        script_args,
        runner=runner,
        scripts_dir=scripts_dir,
    )
    succeeded = completed.returncode == 0
    return {
        "success": succeeded,
        "stdout": completed.stdout,
        "stderr": completed.stderr,
    }
def remove(
    port: int,
    *,
    host: str | None = None,
    runner=None,
    scripts_dir: str | None = None,
) -> dict:
    """Remove a persistent SSH reverse tunnel.

    Runs the bundled ``remove-autossh-service.sh`` script with
    ``-p <port>`` after the allowlist check for the ``"tunnels"``
    capability has passed.

    Parameters
    ----------
    port : int
        The remote port of the tunnel to remove.
    host : str, optional
        Local host label for allowlist gating. Defaults to the short
        local hostname.
    runner : callable, optional
        Forwarded to ``_run_script``; a ``subprocess.run``-shaped callable.
    scripts_dir : str, optional
        Forwarded to ``_run_script``; overrides the scripts directory.

    Returns
    -------
    dict
        ``'success'`` (True when the script's return code is 0),
        ``'stdout'`` and ``'stderr'``.
    """
    gate_host = host if host else _local_host()
    _require_allowed(gate_host, "tunnels")

    completed = _run_script(
        "remove-autossh-service.sh",
        ["-p", str(port)],
        runner=runner,
        scripts_dir=scripts_dir,
    )
    succeeded = completed.returncode == 0
    return {
        "success": succeeded,
        "stdout": completed.stdout,
        "stderr": completed.stderr,
    }
def status(port: int | None = None, *, runner=None) -> dict:
    """Check status of SSH reverse tunnels via ``systemctl``.

    Parameters
    ----------
    port : int, optional
        Specific port to check; queries ``autossh-tunnel-<port>.service``.
        If None, lists all ``autossh-tunnel-*`` units.
    runner : callable, optional
        Subprocess invoker called as
        ``runner(cmd, capture_output=True, text=True)``. Defaults to
        ``subprocess.run``.

    Returns
    -------
    dict
        ``'success'`` (True when the command's return code is 0),
        ``'stdout'`` and ``'stderr'``.
    """
    # Compare against None rather than relying on truthiness, so an explicit
    # ``port=0`` queries that unit instead of silently listing every tunnel.
    if port is not None:
        cmd = [
            "systemctl",
            "status",
            f"autossh-tunnel-{port}.service",
            "--no-pager",
        ]
    else:
        cmd = ["systemctl", "list-units", "autossh-tunnel-*", "--no-pager"]

    if runner is None:
        runner = subprocess.run
    result = runner(cmd, capture_output=True, text=True)
    return {
        "success": result.returncode == 0,
        "stdout": result.stdout,
        "stderr": result.stderr,
    }
def get_version() -> str:
    """Return the scitex-ssh version string held in ``__version__``."""
    return __version__
# EOF