"""
Base experiment launcher class for OpenScope Bonsai workflows.
This module contains the core functionality for launching Bonsai workflows
with parameter management, process monitoring, and basic metadata collection.
"""
import os
import sys
import time
import signal
import logging
import datetime
import platform
import subprocess
import socket
import uuid
import hashlib
import atexit
import psutil
import threading
import shutil
import json
import argparse
from typing import Dict, List, Optional, Any
from decimal import Decimal
# Import AIND data schema utilities for standardized folder naming
try:
from aind_data_schema_models.data_name_patterns import build_data_name
AIND_DATA_SCHEMA_AVAILABLE = True
except ImportError:
AIND_DATA_SCHEMA_AVAILABLE = False
logging.warning("aind-data-schema-models not available. Using fallback folder naming.")
# Import Windows-specific modules for process management
try:
import win32job
import win32api
import win32con
WINDOWS_MODULES_AVAILABLE = True
except ImportError:
WINDOWS_MODULES_AVAILABLE = False
logging.warning("Windows modules not available. Process management will be limited.")
from ..utils.config_loader import ConfigLoader
from ..utils.git_manager import GitManager
from ..utils.process_monitor import ProcessMonitor
from .bonsai_interface import BonsaiInterface
class BaseExperiment:
"""
Base class for OpenScope experimental launchers.
Provides core functionality for:
- Parameter loading and management
- Bonsai process management
- Repository setup and version control
- Process monitoring and memory management
- Basic output file generation
"""
def __init__(self):
"""Initialize the base experiment with core functionality."""
self.platform_info = self._get_platform_info()
self.params = {}
self.bonsai_process = None
self.start_time = None
self.stop_time = None
self.config = {}
# Session tracking variables
self.subject_id = ""
self.user_id = ""
self.session_uuid = ""
self.session_directory = "" # Store the session output directory
self.script_checksum = None
self.params_checksum = None
# Process monitoring
self._percent_used = None
self._restarted = False
self.stdout_data = []
self.stderr_data = []
self._output_threads = []
self._logging_finalized = False # Flag to prevent duplicate logging
# Initialize utility classes
self.config_loader = ConfigLoader()
self.git_manager = GitManager()
self.process_monitor = ProcessMonitor()
self.bonsai_interface = BonsaiInterface()
# Windows job object for process management
self.hJob = None
if WINDOWS_MODULES_AVAILABLE:
self._setup_windows_job()
# Register exit handlers
atexit.register(self.cleanup)
signal.signal(signal.SIGTERM, self.signal_handler)
logging.info("BaseExperiment initialized")
def _get_platform_info(self) -> Dict[str, Any]:
"""Get system and version information for Windows."""
return {
"python": sys.version.split()[0],
"os": ("Windows", platform.release(), platform.version()),
"hardware": (platform.processor(), platform.machine()),
"computer_name": platform.node(),
"rig_id": os.environ.get('RIG_ID', socket.gethostname()),
}
def _setup_windows_job(self):
"""Set up Windows job object for process management."""
try:
self.hJob = win32job.CreateJobObject(None, "BonsaiJobObject")
extended_info = win32job.QueryInformationJobObject(
self.hJob, win32job.JobObjectExtendedLimitInformation
)
extended_info['BasicLimitInformation']['LimitFlags'] = (
win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
)
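# KILL_ON_JOB_CLOSE terminates every process assigned to the job when the
# last handle to the job object is closed (i.e., when this launcher exits),
# so orphaned Bonsai child processes are cleaned up automatically.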
win32job.SetInformationJobObject(
self.hJob, win32job.JobObjectExtendedLimitInformation, extended_info
)
logging.info("Windows job object created for process management")
except Exception as e:
logging.warning(f"Failed to create Windows job object: {e}")
self.hJob = None
def load_parameters(self, param_file: Optional[str]):
"""
Load parameters from a JSON file.
Args:
param_file: Path to the JSON parameter file
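Example (illustrative parameter file contents; keys and values depend on your rig):
    {
        "subject_id": "123456",
        "user_id": "jane.doe",
        "bonsai_path": "workflows/stimulus.bonsai",
        "OutputFolder": "C:/experiment_data"
    }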
"""
if param_file:
with open(param_file, 'r') as f:
self.params = json.load(f)
logging.info(f"Loaded parameters from {param_file}")
# Generate parameter checksum for provenance tracking
with open(param_file, 'rb') as f:
self.params_checksum = hashlib.md5(f.read()).hexdigest()
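# Note: MD5 is used here only as a lightweight fingerprint for provenance
# tracking, not for any security purpose.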
logging.info(f"Parameter file checksum: {self.params_checksum}")
else:
logging.warning("No parameter file provided, using default parameters")
self.params = {}
# Collect runtime information (only for missing values)
runtime_info = self.collect_runtime_information()
# Update parameters with runtime information
self.params.update(runtime_info)
# Load hardware configuration
self.config = self.config_loader.load_config(self.params)
# Extract subject_id and user_id (using runtime info and config as fallbacks)
self.subject_id = (
self.params.get("subject_id") or
(self.config.get("Behavior", {}).get("subject_id") if self.config else "") or
""
)
self.user_id = (
self.params.get("user_id") or
(self.config.get("Behavior", {}).get("user_id") if self.config else "") or
""
)
logging.info(f"Using subject_id: {self.subject_id}, user_id: {self.user_id}")
def start_bonsai(self):
"""Start the Bonsai workflow using BonsaiInterface."""
logging.info(f"Subject ID: {self.subject_id}, User ID: {self.user_id}, Session UUID: {self.session_uuid}")
# Store current memory usage
vmem = psutil.virtual_memory()
self._percent_used = vmem.percent
try:
# Resolve all Bonsai paths relative to repository
resolved_paths = self._resolve_bonsai_paths()
# Create updated params with resolved paths
bonsai_params = self.params.copy()
bonsai_params.update(resolved_paths)
# Setup Bonsai environment (including installation if needed)
if not self.bonsai_interface.setup_bonsai_environment(bonsai_params):
raise RuntimeError("Failed to setup Bonsai environment")
# Get workflow path
workflow_path = self._get_workflow_path()
# Output directory generation is handled by determine_session_directory() (migrated from Bonsai's GenerateRootLoggingPath)
# Construct arguments using BonsaiInterface - pass full params dict
workflow_args = self.bonsai_interface.construct_workflow_arguments(self.params)
# Start workflow using BonsaiInterface
self.bonsai_process = self.bonsai_interface.start_workflow(
workflow_path=workflow_path,
arguments=workflow_args,
output_path=self.session_directory
)
# Create threads to read output streams
self._start_output_readers()
# Assign process to Windows job object if available
if WINDOWS_MODULES_AVAILABLE and self.hJob:
self._assign_to_job_object()
self.start_time = datetime.datetime.now()
logging.info(f"Bonsai started at {self.start_time}")
# Log experiment start
logging.info(f"MID, {self.subject_id}, UID, {self.user_id}, Action, Executing, "
f"Checksum, {self.script_checksum}, Json_checksum, {self.params_checksum}")
# Monitor Bonsai process
self._monitor_bonsai()
except Exception as e:
logging.error(f"Failed to start Bonsai: {e}")
raise
def _resolve_bonsai_paths(self) -> Dict[str, str]:
"""
Resolve all Bonsai-related paths relative to the repository.
Returns:
Dictionary with resolved absolute paths for Bonsai components
"""
repo_path = self.git_manager.get_repository_path(self.params)
resolved_params = {}
# List of path parameters that should be resolved relative to repo
path_params = [
'bonsai_exe_path',
'bonsai_setup_script',
'bonsai_config_path'
]
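# Illustrative resolution: with the repository at C:/repos/my-workflow, a relative
# value such as "bonsai/Bonsai.exe" becomes "C:/repos/my-workflow/bonsai/Bonsai.exe".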
for param_name in path_params:
param_value = self.params.get(param_name)
if param_value:
if os.path.isabs(param_value):
# Already absolute path
resolved_params[param_name] = param_value
elif repo_path:
# Resolve relative to repository
resolved_params[param_name] = os.path.join(repo_path, param_value)
else:
# No repository path available, use as-is
resolved_params[param_name] = param_value
return resolved_params
def _get_workflow_path(self) -> str:
"""Get the absolute path to the Bonsai workflow file."""
bonsai_path = self.params.get('bonsai_path')
if not bonsai_path:
raise ValueError("No Bonsai workflow path specified in parameters")
workflow_path = bonsai_path
if not os.path.isabs(workflow_path):
# Try to find workflow in repository
repo_path = self.git_manager.get_repository_path(self.params)
if repo_path:
workflow_path = os.path.join(repo_path, bonsai_path)
if not os.path.exists(workflow_path):
raise FileNotFoundError(f"Workflow file not found: {workflow_path}")
return workflow_path
def _start_output_readers(self):
"""Start threads to read stdout and stderr in real-time."""
self.stdout_data = []
self.stderr_data = []
def stdout_reader():
for line in iter(self.bonsai_process.stdout.readline, ''):
if line:
self.stdout_data.append(line.rstrip())
logging.info(f"Bonsai output: {line.rstrip()}")
self.bonsai_process.stdout.close()
def stderr_reader():
for line in iter(self.bonsai_process.stderr.readline, ''):
if line:
self.stderr_data.append(line.rstrip())
logging.error(f"Bonsai error: {line.rstrip()}")
self.bonsai_process.stderr.close()
self._output_threads = [
threading.Thread(target=stdout_reader),
threading.Thread(target=stderr_reader)
]
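# Daemon threads: the readers exit with the interpreter instead of blocking
# shutdown if Bonsai is killed while output is still being produced.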
for thread in self._output_threads:
thread.daemon = True
thread.start()
def _assign_to_job_object(self):
"""Assign Bonsai process to Windows job object."""
try:
perms = win32con.PROCESS_TERMINATE | win32con.PROCESS_SET_QUOTA
hProcess = win32api.OpenProcess(perms, False, self.bonsai_process.pid)
win32job.AssignProcessToJobObject(self.hJob, hProcess)
logging.info(f"Bonsai process {self.bonsai_process.pid} assigned to job object")
except Exception as e:
logging.warning(f"Failed to assign process to job object: {e}")
def _monitor_bonsai(self):
"""Monitor the Bonsai process until it completes."""
logging.info("Monitoring Bonsai process...")
try:
# Monitor process with memory usage checking
self.process_monitor.monitor_process(
self.bonsai_process,
self._percent_used,
kill_callback=self.kill_process
)
# Wait for output threads to finish
for thread in self._output_threads:
thread.join(timeout=2.0)
# Check return code and log results
return_code = self.bonsai_process.returncode
if return_code != 0:
logging.error(f"Bonsai exited with code: {return_code}")
if self.stderr_data:
error_msg = "\n".join(self.stderr_data)
logging.error(f"Complete Bonsai error output:\n{error_msg}")
logging.error(f"MID, {self.subject_id}, UID, {self.user_id}, Action, Errored, "
f"Return_code, {return_code}")
else:
logging.info("Bonsai completed successfully")
if self.stderr_data:
warning_msg = "\n".join(self.stderr_data)
logging.warning(f"Bonsai reported warnings:\n{warning_msg}")
self.stop_time = datetime.datetime.now()
duration_min = (self.stop_time - self.start_time).total_seconds() / 60.0
logging.info(f"MID, {self.subject_id}, UID, {self.user_id}, Action, Completed, "
f"Duration_min, {round(duration_min, 2)}")
except Exception as e:
logging.error(f"Error monitoring Bonsai process: {e}")
self.stop()
def kill_process(self):
"""Kill the Bonsai process immediately."""
if self.bonsai_process and self.bonsai_process.poll() is None:
logging.warning("Killing Bonsai process due to excessive memory usage")
try:
self.bonsai_process.kill()
if WINDOWS_MODULES_AVAILABLE:
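# taskkill /T terminates the whole child-process tree; /F forces termination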
subprocess.call(['taskkill', '/F', '/T', '/PID', str(self.bonsai_process.pid)])
except Exception as e:
logging.error(f"Error killing Bonsai process: {e}")
def stop(self):
"""Stop the Bonsai process if it's running."""
self.stop_time = datetime.datetime.now()
if self.bonsai_process and self.bonsai_process.poll() is None:
logging.info("Stopping Bonsai process...")
try:
# Try graceful termination first
self.bonsai_process.terminate()
# Wait for termination
start_time = time.time()
while time.time() - start_time < 3:
if self.bonsai_process.poll() is not None:
logging.info("Bonsai process terminated gracefully")
break
time.sleep(0.1)
# Force kill if needed
if self.bonsai_process.poll() is None:
logging.warning("Forcing kill of Bonsai process")
self.bonsai_process.kill()
# Kill child processes on Windows
if WINDOWS_MODULES_AVAILABLE:
try:
subprocess.call(['taskkill', '/F', '/T', '/PID', str(self.bonsai_process.pid)])
except Exception as e:
logging.warning(f"Could not kill child processes: {e}")
except Exception as e:
logging.error(f"Error stopping Bonsai process: {e}")
# Finalize logging to flush all logs and close handlers
self.finalize_logging()
def get_bonsai_errors(self) -> str:
"""Return any errors reported by Bonsai."""
if not self.stderr_data:
return "No errors reported by Bonsai."
return "\n".join(self.stderr_data)
def cleanup(self):
"""Clean up resources when the script exits."""
logging.info("Cleaning up resources...")
self.stop()
def signal_handler(self, sig, frame):
"""Handle Ctrl+C and other signals."""
logging.info("Received signal to terminate")
self.stop()
sys.exit(0)
def post_experiment_processing(self) -> bool:
"""
Perform post-experiment processing specific to each rig type.
This method should be overridden in each rig-specific launcher.
Default implementation does nothing - each rig should implement
its own data reformatting logic here.
Returns:
True if successful, False otherwise
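Example (illustrative override in a rig-specific subclass):
    def post_experiment_processing(self) -> bool:
        # e.g. reformat raw Bonsai output found in self.session_directory
        return True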
"""
logging.info("No post-experiment processing defined for this rig type")
return True
def _get_experiment_type_name(self) -> str:
"""
Get the name of the experiment type for logging and error messages.
Returns:
String name of the experiment type (default: "Bonsai")
"""
return "Bonsai"
def run(self, param_file: Optional[str] = None) -> bool:
"""
Run the experiment with the given parameters.
Args:
param_file: Path to the JSON parameter file
Returns:
True if successful, False otherwise
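Example (illustrative):
    experiment = BaseExperiment()
    success = experiment.run("path/to/params.json")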
"""
signal.signal(signal.SIGINT, self.signal_handler)
try:
# Set start time
self.start_time = datetime.datetime.now()
# Load parameters
self.load_parameters(param_file)
# Set up repository
if not self.git_manager.setup_repository(self.params):
logging.error("Repository setup failed")
return False
# Determine output directory for data saving
output_directory = self.determine_session_directory()
self.session_directory = output_directory # Store for post-processing
# Set up continuous logging to output directory
if output_directory:
centralized_log_dir = self.params.get("centralized_log_directory")
self.setup_continuous_logging(output_directory, centralized_log_dir)
# Save experiment metadata after logging is set up
self.save_experiment_metadata(output_directory, param_file)
# Start Bonsai
self.start_bonsai()
# Check for errors
if self.bonsai_process.returncode != 0:
logging.error(f"{self._get_experiment_type_name()} experiment failed")
return False
# Perform rig-specific post-processing
if not self.post_experiment_processing():
logging.warning("Post-experiment processing failed, but experiment data was collected")
return True
except Exception as e:
logging.exception(f"{self._get_experiment_type_name()} experiment failed: {e}")
return False
finally:
self.stop()
def determine_session_directory(self) -> Optional[str]:
"""
Determine or generate output directory path using AIND data schema standards.
Args:
None
Returns:
Full path to the output directory, or None if not determinable
"""
try:
# OutputFolder must be specified in the parameters
if "OutputFolder" in self.params:
output_dir = self.params["OutputFolder"]
logging.info(f"Using OutputFolder from parameters: {output_dir}")
else:
logging.error("OutputFolder not specified in parameters")
raise ValueError("OutputFolder not specified in parameters")
subject_id = self.subject_id
# At this point, we need a subject_id to generate the directory name
if not subject_id:
logging.error("Cannot generate output directory: missing subject_id")
raise ValueError("Missing subject_id")
# Generate directory using AIND data schema standards
date_time_offset = datetime.datetime.now()
if AIND_DATA_SCHEMA_AVAILABLE:
try:
# Use AIND standard naming: {subject_id}_{datetime}
folder_name = build_data_name(
label=subject_id,
creation_datetime=date_time_offset
)
logging.info(f"Generated AIND-compliant folder name: {folder_name}")
except Exception as e:
logging.warning(f"Failed to use AIND data schema naming, falling back to default: {e}")
# Fallback to default naming
folder_name = f"{subject_id}_{date_time_offset.strftime('%Y-%m-%d_%H-%M-%S')}"
else:
# Fallback naming when AIND data schema is not available
folder_name = f"{subject_id}_{date_time_offset.strftime('%Y-%m-%d_%H-%M-%S')}"
logging.info(f"Using fallback folder name: {folder_name}")
# Use the generated folder name as the session identifier
self.session_uuid = folder_name
# Create full output directory path
output_dir = os.path.join(output_dir, folder_name)
# Create the directory if it doesn't exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logging.info(f"Created output directory: {output_dir}")
else:
logging.info(f"Output directory already exists: {output_dir}")
return output_dir
except Exception as e:
logging.error(f"Failed to determine/generate output directory: {e}")
return None
@classmethod
def create_argument_parser(cls, description: str = None) -> argparse.ArgumentParser:
"""
Create a standard argument parser for experiment launchers.
Args:
description: Description for the argument parser
Returns:
Configured ArgumentParser instance
"""
if description is None:
description = f"Launch {cls.__name__} experiment"
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
'param_file',
nargs='?',
help='Path to the JSON parameter file. If not provided, will look for default parameter files.'
)
return parser
@classmethod
def run_from_args(cls, args: argparse.Namespace) -> int:
"""
Run the experiment from parsed command line arguments.
Args:
args: Parsed command line arguments
Returns:
Exit code (0 for success, 1 for failure)
"""
# Set up basic logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
try:
# Validate parameter file if provided
if args.param_file and not os.path.exists(args.param_file):
logging.error(f"Parameter file not found: {args.param_file}")
return 1
# Create experiment instance
experiment = cls()
# Run the experiment
experiment_name = cls.__name__.replace('Experiment', '').replace('Launcher', '')
logging.info(f"Starting {experiment_name} with parameters: {args.param_file}")
success = experiment.run(args.param_file)
if success:
logging.info(f"===== {experiment_name.upper()} COMPLETED SUCCESSFULLY =====")
return 0
else:
logging.error(f"===== {experiment_name.upper()} FAILED =====")
logging.error("Check the logs above for error details.")
return 1
except KeyboardInterrupt:
logging.info("Experiment interrupted by user")
return 1
except Exception as e:
logging.error(f"Unexpected error: {e}")
return 1
@classmethod
def main(cls, description: str = None, args: List[str] = None) -> int:
"""
Main entry point for experiment launchers.
Args:
description: Description for the argument parser
args: Command line arguments (defaults to sys.argv)
Returns:
Exit code (0 for success, 1 for failure)
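Example (illustrative launcher entry point):
    if __name__ == "__main__":
        sys.exit(BaseExperiment.main("Launch base experiment"))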
"""
parser = cls.create_argument_parser(description)
parsed_args = parser.parse_args(args)
return cls.run_from_args(parsed_args)
def setup_continuous_logging(self, output_directory: str, centralized_log_dir: Optional[str] = None):
"""
Set up continuous logging to output directory and optionally centralized location.
Args:
output_directory: Directory where experiment-specific logs should be saved
centralized_log_dir: Optional centralized logging directory
"""
try:
# Create log filename from the session identifier
subject_id = self.params.get('subject_id')
log_filename = f"experiment_{self.session_uuid}.log"
# Set up logging format
log_format = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Get root logger
root_logger = logging.getLogger()
# 1. Set up file handler for output directory
output_log_path = os.path.join(output_directory, log_filename)
os.makedirs(os.path.dirname(output_log_path), exist_ok=True)
output_handler = logging.FileHandler(output_log_path)
output_handler.setLevel(logging.DEBUG)
output_handler.setFormatter(log_format)
root_logger.addHandler(output_handler)
logging.info(f"Continuous logging started: {output_log_path}")
# 2. Set up centralized logging if specified
if centralized_log_dir:
# Create centralized log directory structure: YYYY/MM/DD/
date_path = datetime.datetime.now().strftime('%Y/%m/%d')
centralized_dir = os.path.join(centralized_log_dir, date_path)
os.makedirs(centralized_dir, exist_ok=True)
centralized_log_path = os.path.join(centralized_dir, log_filename)
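# e.g. <centralized_log_dir>/2024/06/15/experiment_<session_uuid>.log (illustrative)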
centralized_handler = logging.FileHandler(centralized_log_path)
centralized_handler.setLevel(logging.DEBUG)
centralized_handler.setFormatter(log_format)
root_logger.addHandler(centralized_handler)
logging.info(f"Centralized logging started: {centralized_log_path}")
# 3. Set up console handler if not already present
# (note: FileHandler subclasses StreamHandler, so exclude file handlers from the check)
if not any(isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler) for h in root_logger.handlers):
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
root_logger.addHandler(console_handler)
# Set overall logging level
root_logger.setLevel(logging.DEBUG)
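# Root logger stays at DEBUG so the individual handler levels decide what each
# destination actually records.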
# Log session information
logging.info("="*60)
logging.info("EXPERIMENT SESSION STARTED")
logging.info(f"Session UUID: {self.session_uuid}")
logging.info(f"Subject ID: {subject_id}")
logging.info(f"User ID: {self.user_id}")
logging.info(f"Platform: {self.platform_info}")
logging.info(f"Output Directory: {output_directory}")
if centralized_log_dir:
logging.info(f"Centralized Logs: {centralized_log_dir}")
logging.info("="*60)
except Exception as e:
print(f"Failed to set up continuous logging: {e}")
# Continue without file logging - at least console will work
def finalize_logging(self):
"""
Finalize logging at the end of the experiment.
Logs final session information and closes file handlers.
"""
# Prevent duplicate finalization
if self._logging_finalized:
return
self._logging_finalized = True
try:
# Log final session information
logging.info("="*60)
logging.info("EXPERIMENT SESSION COMPLETED")
logging.info(f"Session UUID: {self.session_uuid}")
logging.info(f"Start Time: {self.start_time}")
logging.info(f"Stop Time: {self.stop_time}")
if self.start_time and self.stop_time:
duration = self.stop_time - self.start_time
logging.info(f"Duration: {duration}")
logging.info(f"Final Memory Usage: {psutil.virtual_memory().percent}%")
logging.info("="*60)
# Close and remove file handlers to ensure logs are flushed
root_logger = logging.getLogger()
handlers_to_remove = []
for handler in root_logger.handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
handlers_to_remove.append(handler)
for handler in handlers_to_remove:
root_logger.removeHandler(handler)
except Exception as e:
print(f"Error finalizing logging: {e}")