Entropyk/_bmad/external_model.rs

556 lines
15 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! External Component Model Interface
//!
//! This module provides support for external component models via:
//! - Dynamic library loading (.dll/.so) via FFI
//! - HTTP API calls to external services
//!
//! ## Architecture
//!
//! The external model interface allows integration of proprietary or vendor-supplied
//! component models that cannot be implemented natively in Rust.
//!
//! ## FFI Interface (DLL/SO)
//!
//! External libraries must implement the `entropyk_model` C ABI:
//!
//! ```c
//! // Required exported functions:
//! int entropyk_model_compute(double* inputs, double* outputs, int n_in, int n_out);
//! int entropyk_model_jacobian(double* inputs, double* jacobian, int n_in, int n_out);
//! const char* entropyk_model_name(void);
//! const char* entropyk_model_version(void);
//! ```
//!
//! ## HTTP API Interface
//!
//! External services must provide REST endpoints:
//!
//! - `POST /compute`: Accepts JSON with inputs, returns JSON with outputs
//! - `POST /jacobian`: Accepts JSON with inputs, returns JSON with Jacobian matrix
use crate::ComponentError;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
/// Configuration describing how to reach an external model and its I/O shape.
///
/// When deserializing, a missing `timeout_ms` field falls back to
/// `default_timeout()` (5000 ms).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalModelConfig {
    /// Unique identifier for this model.
    pub id: String,
    /// How the model is reached (FFI library or HTTP service).
    pub model_type: ExternalModelType,
    /// Number of input values the model expects.
    pub n_inputs: usize,
    /// Number of output values the model produces.
    pub n_outputs: usize,
    /// Timeout in milliseconds.
    ///
    /// The field itself is always present (it is not an `Option`); only its
    /// serialized form may be omitted, in which case it defaults to 5000 ms.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,
}
/// Serde default for `ExternalModelConfig::timeout_ms`: five seconds, in milliseconds.
fn default_timeout() -> u64 {
    const FIVE_SECONDS_MS: u64 = 5_000;
    FIVE_SECONDS_MS
}
/// How an external model is reached.
///
/// `Debug` is implemented by hand so that an HTTP `api_key` is never written
/// to logs or panic messages. A configured key prints as `Some("<redacted>")`,
/// and an absent one prints as `None`. All other fields print as the derived
/// implementation would.
#[derive(Clone, PartialEq, Serialize, Deserialize)]
pub enum ExternalModelType {
    /// Model loaded from a dynamic library (.dll on Windows, .so on Linux,
    /// .dylib on macOS).
    Ffi {
        /// Filesystem path of the library file.
        library_path: PathBuf,
        /// Optional prefix for the exported function names.
        function_prefix: Option<String>,
    },
    /// Model served over an HTTP REST API.
    Http {
        /// Base URL of the service.
        base_url: String,
        /// Optional API key used for authentication (secret; redacted in `Debug`).
        api_key: Option<String>,
    },
}

impl std::fmt::Debug for ExternalModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExternalModelType::Ffi {
                library_path,
                function_prefix,
            } => f
                .debug_struct("Ffi")
                .field("library_path", library_path)
                .field("function_prefix", function_prefix)
                .finish(),
            ExternalModelType::Http { base_url, api_key } => f
                .debug_struct("Http")
                .field("base_url", base_url)
                // Report only whether a key is configured, never its value.
                .field("api_key", &api_key.as_ref().map(|_| "<redacted>"))
                .finish(),
        }
    }
}
/// Common interface over external (FFI or HTTP) component models.
///
/// Both back-ends are exposed to the solver through this single trait. The
/// `Send + Sync` supertraits let implementations be shared across threads
/// (e.g. behind an `Arc<dyn ExternalModel>`).
pub trait ExternalModel: Send + Sync {
    /// Identifier of this model instance.
    fn id(&self) -> &str;

    /// Number of values expected in the `inputs` slice of `compute`/`jacobian`.
    fn n_inputs(&self) -> usize;

    /// Number of values returned by `compute`.
    fn n_outputs(&self) -> usize;

    /// Evaluates the model at `inputs`.
    ///
    /// `inputs` should hold `n_inputs()` values; on success the returned
    /// vector holds `n_outputs()` values.
    ///
    /// # Errors
    ///
    /// Returns an [`ExternalModelError`] when the evaluation cannot be performed.
    fn compute(&self, inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError>;

    /// Evaluates the Jacobian `d outputs / d inputs` at `inputs`.
    ///
    /// The matrix is flattened row-major with shape `n_outputs() × n_inputs()`,
    /// so entry `(i, j)` is stored at index `i * n_inputs() + j`.
    ///
    /// # Errors
    ///
    /// Returns an [`ExternalModelError`] when the Jacobian cannot be computed.
    fn jacobian(&self, inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError>;

    /// Descriptive metadata (name, version, input/output labels).
    fn metadata(&self) -> ExternalModelMetadata;
}
/// Descriptive information returned by `ExternalModel::metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExternalModelMetadata {
    /// Human-readable model name.
    pub name: String,
    /// Version string reported by the model.
    pub version: String,
    /// Free-form description, if one is available.
    pub description: Option<String>,
    /// One label (name and/or unit) per input.
    pub input_names: Vec<String>,
    /// One label (name and/or unit) per output.
    pub output_names: Vec<String>,
}
/// Failure modes of external model operations.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute
/// (generated by `thiserror`).
#[derive(Debug, Clone, thiserror::Error)]
pub enum ExternalModelError {
    /// The dynamic library could not be loaded.
    #[error("Failed to load library: {0}")]
    LibraryLoad(String),

    /// A required function is missing from the loaded library.
    #[error("Function not found: {0}")]
    FunctionNotFound(String),

    /// The model reported a failure while evaluating.
    #[error("Computation failed: {0}")]
    ComputationFailed(String),

    /// The caller supplied the wrong number of input values.
    #[error("Invalid input dimensions: expected {expected}, got {actual}")]
    InvalidInputDimensions {
        /// Number of inputs the model expects.
        expected: usize,
        /// Number of inputs actually supplied.
        actual: usize,
    },

    /// An HTTP request to the model service failed.
    #[error("HTTP request failed: {0}")]
    HttpError(String),

    /// The operation exceeded its time budget; carries the limit in milliseconds.
    #[error("Operation timed out after {0}ms")]
    Timeout(u64),

    /// A JSON payload could not be encoded or decoded.
    #[error("JSON error: {0}")]
    JsonError(String),

    /// The model is not available. The stub constructors in this file return
    /// this when the corresponding feature is disabled.
    #[error("Model not initialized")]
    NotInitialized,
}
impl From<ExternalModelError> for ComponentError {
fn from(err: ExternalModelError) -> Self {
ComponentError::InvalidState(format!("External model error: {}", err))
}
}
/// JSON body sent to the HTTP `POST /compute` endpoint.
// The stub `HttpModel` never builds this type, so without the HTTP client
// (presumably the `http`-feature implementation) it is unused.
#[allow(dead_code)]
#[derive(Debug, Serialize)]
struct ComputeRequest {
    /// Model input values.
    inputs: Vec<f64>,
}

/// JSON body returned by the HTTP `POST /compute` endpoint.
// Unused in stub builds for the same reason as `ComputeRequest`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ComputeResponse {
    /// Model output values.
    outputs: Vec<f64>,
}
/// JSON body sent to the HTTP `POST /jacobian` endpoint.
// The stub `HttpModel` never builds this type, so without the HTTP client
// (presumably the `http`-feature implementation) it is unused.
#[allow(dead_code)]
#[derive(Debug, Serialize)]
struct JacobianRequest {
    /// Point at which the Jacobian is evaluated.
    inputs: Vec<f64>,
}

/// JSON body returned by the HTTP `POST /jacobian` endpoint.
// Unused in stub builds for the same reason as `JacobianRequest`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct JacobianResponse {
    /// Flattened Jacobian matrix.
    jacobian: Vec<f64>,
}
/// Placeholder FFI model, compiled only when the `ffi` feature is disabled.
///
/// No library is ever loaded here: full FFI support requires the `libloading`
/// crate and unsafe code. [`FfiModel::new`] therefore always fails, and
/// [`FfiModel::new_mock`] simply stores the configuration and metadata it is given.
#[cfg(not(feature = "ffi"))]
pub struct FfiModel {
    config: ExternalModelConfig,
    metadata: ExternalModelMetadata,
}

#[cfg(not(feature = "ffi"))]
impl FfiModel {
    /// Attempts to create an FFI-backed model.
    ///
    /// # Errors
    ///
    /// Always returns [`ExternalModelError::NotInitialized`] in this build; the
    /// configuration is ignored.
    pub fn new(_config: ExternalModelConfig) -> Result<Self, ExternalModelError> {
        Err(ExternalModelError::NotInitialized)
    }

    /// Builds a stub instance from explicit configuration and metadata (for tests).
    ///
    /// Never fails; the `Result` only mirrors the signature of [`FfiModel::new`].
    pub fn new_mock(
        config: ExternalModelConfig,
        metadata: ExternalModelMetadata,
    ) -> Result<Self, ExternalModelError> {
        let model = FfiModel { config, metadata };
        Ok(model)
    }
}
#[cfg(not(feature = "ffi"))]
impl ExternalModel for FfiModel {
fn id(&self) -> &str {
&self.config.id
}
fn n_inputs(&self) -> usize {
self.config.n_inputs
}
fn n_outputs(&self) -> usize {
self.config.n_outputs
}
fn compute(&self, _inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError> {
// Stub implementation
Ok(vec![0.0; self.config.n_outputs])
}
fn jacobian(&self, _inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError> {
// Stub implementation - returns identity matrix
let n = self.config.n_inputs * self.config.n_outputs;
Ok(vec![0.0; n])
}
fn metadata(&self) -> ExternalModelMetadata {
self.metadata.clone()
}
}
/// Placeholder HTTP model, compiled only when the `http` feature is disabled.
///
/// No request is ever sent here: full HTTP support requires the `reqwest`
/// crate. [`HttpModel::new`] therefore always fails, and
/// [`HttpModel::new_mock`] simply stores the configuration and metadata it is given.
#[cfg(not(feature = "http"))]
pub struct HttpModel {
    config: ExternalModelConfig,
    metadata: ExternalModelMetadata,
}

#[cfg(not(feature = "http"))]
impl HttpModel {
    /// Attempts to create an HTTP-backed model.
    ///
    /// # Errors
    ///
    /// Always returns [`ExternalModelError::NotInitialized`] in this build; the
    /// configuration is ignored.
    pub fn new(_config: ExternalModelConfig) -> Result<Self, ExternalModelError> {
        Err(ExternalModelError::NotInitialized)
    }

    /// Builds a stub instance from explicit configuration and metadata (for tests).
    ///
    /// Never fails; the `Result` only mirrors the signature of [`HttpModel::new`].
    pub fn new_mock(
        config: ExternalModelConfig,
        metadata: ExternalModelMetadata,
    ) -> Result<Self, ExternalModelError> {
        let model = HttpModel { config, metadata };
        Ok(model)
    }
}
#[cfg(not(feature = "http"))]
impl ExternalModel for HttpModel {
fn id(&self) -> &str {
&self.config.id
}
fn n_inputs(&self) -> usize {
self.config.n_inputs
}
fn n_outputs(&self) -> usize {
self.config.n_outputs
}
fn compute(&self, _inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError> {
Ok(vec![0.0; self.config.n_outputs])
}
fn jacobian(&self, _inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError> {
Ok(vec![0.0; self.config.n_inputs * self.config.n_outputs])
}
fn metadata(&self) -> ExternalModelMetadata {
self.metadata.clone()
}
}
/// Cheaply clonable, shareable handle to an external model.
///
/// This is a thin `Arc<dyn ExternalModel>` wrapper. Cloning it shares the same
/// model instance rather than copying it.
///
/// It adds **no** synchronization of its own. Thread-safety comes from the
/// `Send + Sync` supertraits of `ExternalModel`, so any interior mutability
/// inside a model must be synchronized by that model itself.
pub struct ThreadSafeExternalModel {
    inner: Arc<dyn ExternalModel>,
}
impl ThreadSafeExternalModel {
/// Creates a new thread-safe wrapper.
pub fn new(model: impl ExternalModel + 'static) -> Self {
Self {
inner: Arc::new(model),
}
}
/// Creates from an existing Arc.
pub fn from_arc(model: Arc<dyn ExternalModel>) -> Self {
Self { inner: model }
}
/// Returns a reference to the inner model.
pub fn inner(&self) -> &dyn ExternalModel {
self.inner.as_ref()
}
}
impl Clone for ThreadSafeExternalModel {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// Prints only the wrapped model's id, because `ExternalModel` has no `Debug`
/// supertrait and the model itself cannot be formatted.
impl std::fmt::Debug for ThreadSafeExternalModel {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id: &str = self.inner.id();
        let mut builder = formatter.debug_struct("ThreadSafeExternalModel");
        builder.field("id", &id);
        builder.finish()
    }
}
/// In-process stand-in for an external model, driven by a plain function
/// pointer. Intended for tests.
#[derive(Debug, Clone)]
pub struct MockExternalModel {
    id: String,
    n_inputs: usize,
    n_outputs: usize,
    compute_fn: fn(&[f64]) -> Vec<f64>,
}

impl MockExternalModel {
    /// Creates a mock model with the given id, dimensions and evaluation function.
    pub fn new(
        id: impl Into<String>,
        n_inputs: usize,
        n_outputs: usize,
        compute_fn: fn(&[f64]) -> Vec<f64>,
    ) -> Self {
        let id = id.into();
        MockExternalModel {
            id,
            n_inputs,
            n_outputs,
            compute_fn,
        }
    }

    /// Identity model `y = x` with `n` inputs and `n` outputs (id `"linear_passthrough"`).
    pub fn linear_passthrough(n: usize) -> Self {
        fn identity(x: &[f64]) -> Vec<f64> {
            x.to_vec()
        }
        Self::new("linear_passthrough", n, n, identity)
    }

    /// Element-wise `y = 2x` with `n` inputs and `n` outputs (id `"doubler"`).
    pub fn doubler(n: usize) -> Self {
        fn double(x: &[f64]) -> Vec<f64> {
            x.iter().map(|&v| 2.0 * v).collect()
        }
        Self::new("doubler", n, n, double)
    }
}
impl MockExternalModel {
    /// Returns `InvalidInputDimensions` unless `inputs` has exactly `n_inputs` values.
    fn check_input_len(&self, inputs: &[f64]) -> Result<(), ExternalModelError> {
        if inputs.len() == self.n_inputs {
            Ok(())
        } else {
            Err(ExternalModelError::InvalidInputDimensions {
                expected: self.n_inputs,
                actual: inputs.len(),
            })
        }
    }
}

impl ExternalModel for MockExternalModel {
    fn id(&self) -> &str {
        &self.id
    }

    fn n_inputs(&self) -> usize {
        self.n_inputs
    }

    fn n_outputs(&self) -> usize {
        self.n_outputs
    }

    /// Evaluates `compute_fn` after validating the input and output lengths.
    ///
    /// # Errors
    ///
    /// * [`ExternalModelError::InvalidInputDimensions`] if `inputs.len() != n_inputs`.
    /// * [`ExternalModelError::ComputationFailed`] if `compute_fn` returns a
    ///   vector whose length differs from `n_outputs`.
    fn compute(&self, inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError> {
        self.check_input_len(inputs)?;
        let outputs = (self.compute_fn)(inputs);
        if outputs.len() != self.n_outputs {
            return Err(ExternalModelError::ComputationFailed(format!(
                "model '{}' returned {} outputs, expected {}",
                self.id,
                outputs.len(),
                self.n_outputs
            )));
        }
        Ok(outputs)
    }

    /// Approximates the Jacobian with central finite differences.
    ///
    /// The result is row-major with shape `n_outputs × n_inputs`: entry
    /// `d y_i / d x_j` is stored at index `i * n_inputs + j`.
    ///
    /// The step for input `j` is `1e-6 * max(1, |x_j|)`, so it stays
    /// significant for large-magnitude inputs. The divisor is the step actually
    /// realised in floating point (`(x_j + h) - (x_j - h)`), which cancels the
    /// rounding of the perturbed points.
    ///
    /// # Errors
    ///
    /// Same as `compute`. In particular, a wrong-length `inputs` is reported as
    /// `InvalidInputDimensions` instead of panicking on an out-of-bounds index.
    fn jacobian(&self, inputs: &[f64]) -> Result<Vec<f64>, ExternalModelError> {
        const REL_STEP: f64 = 1e-6;

        self.check_input_len(inputs)?;

        let mut jacobian = vec![0.0; self.n_outputs * self.n_inputs];
        // One scratch buffer, perturbed in place and restored after each column.
        let mut probe = inputs.to_vec();
        for (j, &x) in inputs.iter().enumerate() {
            let h = REL_STEP * x.abs().max(1.0);
            let x_plus = x + h;
            let x_minus = x - h;

            probe[j] = x_plus;
            let y_plus = self.compute(&probe)?;
            probe[j] = x_minus;
            let y_minus = self.compute(&probe)?;
            probe[j] = x;

            // `compute` guarantees both vectors have exactly `n_outputs` entries.
            let step = x_plus - x_minus;
            for (i, (yp, ym)) in y_plus.iter().zip(&y_minus).enumerate() {
                jacobian[i * self.n_inputs + j] = (yp - ym) / step;
            }
        }
        Ok(jacobian)
    }

    fn metadata(&self) -> ExternalModelMetadata {
        ExternalModelMetadata {
            name: self.id.clone(),
            version: "1.0.0".to_string(),
            description: Some("Mock external model for testing".to_string()),
            input_names: (0..self.n_inputs).map(|i| format!("input_{}", i)).collect(),
            output_names: (0..self.n_outputs)
                .map(|i| format!("output_{}", i))
                .collect(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_external_model_compute() {
        let model = MockExternalModel::doubler(3);
        let result = model.compute(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(result, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_mock_external_model_dimensions() {
        let model = MockExternalModel::doubler(3);
        assert_eq!(model.n_inputs(), 3);
        assert_eq!(model.n_outputs(), 3);
    }

    #[test]
    fn test_mock_external_model_invalid_input() {
        let model = MockExternalModel::doubler(3);
        let result = model.compute(&[1.0, 2.0]);
        // Pin the exact variant and payload, not just "some error".
        assert!(matches!(
            result,
            Err(ExternalModelError::InvalidInputDimensions {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_mock_external_model_jacobian() {
        let model = MockExternalModel::doubler(2);
        let jac = model.jacobian(&[1.0, 2.0]).unwrap();
        // Jacobian of y = 2x is [[2, 0], [0, 2]], stored row-major.
        let expected = [2.0, 0.0, 0.0, 2.0];
        assert_eq!(jac.len(), expected.len());
        for (got, want) in jac.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.01, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_thread_safe_wrapper() {
        let model = MockExternalModel::doubler(2);
        let wrapped = ThreadSafeExternalModel::new(model);
        let result = wrapped.inner().compute(&[1.0, 2.0]).unwrap();
        assert_eq!(result, vec![2.0, 4.0]);
    }

    #[test]
    fn test_thread_safe_clone() {
        let model = MockExternalModel::doubler(2);
        let wrapped = ThreadSafeExternalModel::new(model);
        let cloned = wrapped.clone();
        assert_eq!(wrapped.inner().id(), cloned.inner().id());
    }

    #[test]
    fn test_thread_safe_wrapper_is_shareable_across_threads() {
        let wrapped = ThreadSafeExternalModel::new(MockExternalModel::doubler(1));
        let shared = wrapped.clone();
        let handle = std::thread::spawn(move || shared.inner().compute(&[3.0]).unwrap());
        assert_eq!(handle.join().unwrap(), vec![6.0]);
        // The original handle stays usable after a clone moved to another thread.
        assert_eq!(wrapped.inner().compute(&[1.0]).unwrap(), vec![2.0]);
    }

    #[test]
    fn test_external_model_metadata() {
        let model = MockExternalModel::doubler(2);
        let meta = model.metadata();
        assert_eq!(meta.name, "doubler");
        assert_eq!(meta.version, "1.0.0");
        assert_eq!(meta.input_names, vec!["input_0", "input_1"]);
        assert_eq!(meta.output_names, vec!["output_0", "output_1"]);
    }

    #[test]
    fn test_linear_passthrough_model() {
        let model = MockExternalModel::linear_passthrough(3);
        let result = model.compute(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_external_model_config() {
        let config = ExternalModelConfig {
            id: "test_model".to_string(),
            model_type: ExternalModelType::Http {
                base_url: "http://localhost:8080".to_string(),
                api_key: Some("secret".to_string()),
            },
            n_inputs: 4,
            n_outputs: 2,
            timeout_ms: 3000,
        };
        assert_eq!(config.id, "test_model");
        assert_eq!(config.n_inputs, 4);
        assert_eq!(config.n_outputs, 2);
        assert_eq!(config.timeout_ms, 3000);
    }

    #[test]
    fn test_error_conversion() {
        let err = ExternalModelError::ComputationFailed("test error".to_string());
        let component_err: ComponentError = err.into();
        match component_err {
            ComponentError::InvalidState(msg) => {
                // The full message is the prefix followed by the source error's Display.
                assert_eq!(msg, "External model error: Computation failed: test error");
            }
            _ => panic!("Expected InvalidState error"),
        }
    }
}