diff --git a/Cargo.lock b/Cargo.lock index 12b8273..a0d19e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,15 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" +[[package]] +name = "approx" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278" +dependencies = [ + "num-traits", +] + [[package]] name = "approx" version = "0.5.1" @@ -224,6 +233,8 @@ dependencies = [ "anyhow", "cfg-if", "nalgebra", + "ndarray", + "ndarray-linalg", "num-complex", "num-integer", "num-traits", @@ -667,6 +678,27 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cauchy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ff11ddd2af3b5e80dd0297fee6e56ac038d9bdc549573cdb51bd6d2efe7f05e" +dependencies = [ + "num-complex", + "num-traits", + "rand", + "serde", +] + +[[package]] +name = "cblas-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" +dependencies = [ + "libc", +] + [[package]] name = "cc" version = "1.0.79" @@ -1565,12 +1597,45 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "katexit" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb1304c448ce2c207c2298a34bc476ce7ae47f63c23fa2b498583b26be9bc88c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "khronos_api" version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" +[[package]] +name = "lapack-sys" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "447f56c85fb410a7a3d36701b2153c1018b1d2b908c5fbaf01c1b04fac33bcbe" +dependencies = [ + "libc", +] + +[[package]] +name = "lax" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f96a229d9557112e574164f8024ce703625ad9f88a90964c1780809358e53da" +dependencies = [ + "cauchy", + "katexit", + "lapack-sys", + "num-traits", + "thiserror", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1613,9 +1678,9 @@ checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" [[package]] name = "linux-raw-sys" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b64f40e5e03e0d54f03845c8197d0291253cdbedfb1cb46b13c2c117554a9f4c" +checksum = "ece97ea872ece730aed82664c424eb4c8291e1ff2480247ccf7409044bc6479f" [[package]] name = "lock_api" @@ -1731,12 +1796,15 @@ version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d68d47bba83f9e2006d117a9a33af1524e655516b8919caac694427a6fb1e511" dependencies = [ - "approx", + "approx 0.5.1", "matrixmultiply", "nalgebra-macros", "num-complex", "num-rational", "num-traits", + "rand", + "rand_distr", + "serde", "simba", "typenum", ] @@ -1762,6 +1830,39 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "approx 0.4.0", + "cblas-sys", + "libc", + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "serde", +] + +[[package]] +name = "ndarray-linalg" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b0e8dda0c941b64a85c5deb2b3e0144aca87aced64678adfc23eacea6d2cc42" +dependencies = [ + "cauchy", + "katexit", + "lax", + "ndarray", + "num-complex", + "num-traits", + "rand", + "thiserror", +] + [[package]] name = "ndk" version = "0.7.0" @@ -1852,6 +1953,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" dependencies = [ "num-traits", + "rand", + "serde", ] [[package]] @@ -2218,6 +2321,16 @@ dependencies = [ "serde", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -2325,9 +2438,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.18" +version = "0.37.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bbfc1d1c7c40c01715f47d71444744a81669ca84e8b63e25a55e169b1f86433" +checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" dependencies = [ "bitflags", "errno", @@ -2469,7 +2582,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" dependencies = [ - "approx", + "approx 0.5.1", "num-complex", "num-traits", "paste", @@ -2496,6 +2609,8 @@ dependencies = [ "nalgebra", "nalgebra-sparse", "nohash-hasher", + "rand", + "rand_distr", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index bc4e26a..63cbd48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,14 +7,16 @@ edition = "2021" [dependencies] argmin = { version = "0.8.1", features = [] } -argmin-math = { version = "0.3.0", features = ["nalgebra_0_32"] } +argmin-math = { version = "0.3.0", features = ["latest_all"] } bevy_ecs = "0.10.1" -eframe = { version = "0.21.3", features = [] } +eframe = { version = "0.21.3" } indexmap = "1.8.1" levenberg-marquardt = "0.13.0" -nalgebra = "0.32.0" +nalgebra = { version = "0.32.0", features = ["rand"] } nalgebra-sparse = "0.9.0" nohash-hasher = "0.2.0" +rand = "0.8.5" +rand_distr = "0.4.3" [dev-dependencies] criterion = "0.4.0" diff --git a/src/geometry.rs b/src/geometry.rs index d26d436..fce2890 100644 --- a/src/geometry.rs +++ b/src/geometry.rs @@ -1,11 +1,11 @@ mod var; -use std::ops::DerefMut; +use std::ops::{Deref, DerefMut}; use bevy_ecs::{ entity::Entity, - prelude::Component, - query::QueryEntityError, + prelude::{Bundle, Component}, + query::{QueryEntityError, With}, system::{Commands, Query, SystemParam}, }; @@ -13,7 +13,7 @@ pub use self::var::*; pub type Scalar = f64; -#[derive(Clone, Debug, Component)] +#[derive(Default, Clone, Debug, Component)] pub struct Point { pub x: V, pub y: V, @@ -34,7 +34,7 @@ impl From<(V, V)> for Point { pub type PointPos = Point; pub type PointId = Entity; -#[derive(Clone, Debug, Component)] +#[derive(Clone, Debug, Default, Component)] // TODO: segment, ray and full line // TODO: what if start and end are coincident? pub struct Line

{ @@ -60,35 +60,45 @@ pub struct Circle { // pub type CirclePos = CircleEntity; +#[derive(Bundle)] +pub struct PointBundle { + pub point: Point, + pub computed_pos: ComputedPointPos, +} + +impl PointBundle { + pub fn new(x: VarId, y: VarId) -> Self { + Self { + point: Point::new(x, y), + computed_pos: Default::default(), + } + } +} + pub fn insert_point_at(commands: &mut Commands, point: impl Into) -> PointId { let point = point.into(); let x = commands.spawn(Var::new_free(point.x)).id(); let y = commands.spawn(Var::new_free(point.y)).id(); - commands.spawn(Point::new(x, y)).id() + commands.spawn(PointBundle::new(x, y)).id() } -// TODO: figure out generic for this -#[derive(SystemParam)] -pub struct PointPosQuery<'w, 's> { - vars: Query<'w, 's, &'static Var>, +#[derive(Component, Default)] +pub struct ComputedPointPos(Var); + +impl Deref for ComputedPointPos { + type Target = Var; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -// type PointPosQuery<'w, 's> = PointPosQueryImpl<'w, 's, &'static Var>; -// type PointPosQueryMut<'w, 's> = PointPosQueryImpl<'w, 's, &'static mut Var>; - -// impl<'w, 's, Q> PointPosQueryImpl<'w, 's, Q> -// where -// Q: WorldQuery + 'static, -// for<'w2, 's2> <::ReadOnlyFetch as Fetch<'w2, 's2>>::Item: Deref, -impl<'w, 's> PointPosQuery<'w, 's> { - pub fn try_get(&self, point: &Point) -> Result, QueryEntityError> { - let x = self.vars.get(point.x)?; - let y = self.vars.get(point.y)?; - Ok(x.merge(*y, PointPos::new)) - } - pub fn get(&self, point: &Point) -> Var { - self.try_get(point).unwrap() - } +pub fn update_point_pos(mut points: Query<(&Point, &mut ComputedPointPos)>, vars: Query<&Var>) { + points.for_each_mut(|(point, mut computed_pos)| { + let x = vars.get(point.x).unwrap(); + let y = vars.get(point.y).unwrap(); + computed_pos.0 = x.merge(*y, PointPos::new); + }); } #[derive(SystemParam)] @@ -116,24 +126,51 @@ impl<'w, 's> PointPosQueryMut<'w, 's> } } -#[derive(SystemParam)] -pub struct LinePosQuery<'w, 's> { - point_pos: PointPosQuery<'w, 's>, - points: Query<'w, 's, &'static Point>, -} +#[derive(Default, Component)] +pub struct ComputedLinePos(Var); -impl<'w, 's> LinePosQuery<'w, 's> { - pub fn get(&self, line: &Line) -> Var { - self.try_get(line).unwrap() - } +impl Deref for ComputedLinePos { + type Target = Var; - pub fn try_get(&self, line: &Line) -> Result, QueryEntityError> { - // TODO: error handling? - let start = self.points.get(line.start)?; - let start = self.point_pos.try_get(start)?; - - let end = self.points.get(line.end)?; - let end = self.point_pos.try_get(end)?; - Ok(start.merge(end, LinePos::new)) + fn deref(&self) -> &Self::Target { + &self.0 } } + +pub fn update_line_pos( + mut lines: Query<(&Line, &mut ComputedLinePos)>, + points: Query<&ComputedPointPos>, +) { + lines.for_each_mut(|(line, mut computed_pos)| { + let Ok(start) = points.get(line.start) else { return }; + let Ok(end) = points.get(line.end) else { return} ; + computed_pos.0 = (*start).clone().merge((*end).clone(), LinePos::new); + }); +} + +#[derive(Bundle)] +pub struct LineBundle { + pub line: Line, + pub computed_pos: ComputedLinePos, +} + +impl LineBundle { + pub fn new(start: PointId, end: PointId) -> Self { + Self { + line: Line::new(start, end), + computed_pos: Default::default(), + } + } +} + +pub fn remove_dangling_lines( + mut lines: Query<(Entity, &Line)>, + points: Query<(), With>, + mut commands: Commands, +) { + lines.for_each_mut(|(id, line)| { + if !points.contains(line.start) || !points.contains(line.end) { + commands.entity(id).despawn(); + } + }); +} diff --git a/src/geometry/var.rs b/src/geometry/var.rs index 2f477a1..16fac2c 100644 --- a/src/geometry/var.rs +++ b/src/geometry/var.rs @@ -1,4 +1,4 @@ -use bevy_ecs::{prelude::Component, entity::Entity}; +use bevy_ecs::{entity::Entity, prelude::Component}; use super::Scalar; @@ -28,6 +28,12 @@ pub struct Var { pub status: VarStatus, } +impl Default for Var { + fn default() -> Self { + Self::new_free(T::default()) + } +} + impl Var { pub fn new_free(value: T) -> Self { Self { diff --git a/src/main.rs b/src/main.rs index f053af8..4bbad3b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,17 +10,12 @@ use eframe::{ emath::{Pos2, Rect, RectTransform, Vec2}, epaint::{Color32, Hsva, Stroke}, }; -use geometry::{Line, LinePosQuery, Point, PointId, PointPos, PointPosQuery, PointPosQueryMut}; +use geometry::{Line, LineBundle, Point, PointId, PointPos, PointPosQueryMut}; mod geometry; -mod optimization; +pub mod optimization; mod relations; -fn main() { - let options = eframe::NativeOptions::default(); - eframe::run_native("sketchrs", options, Box::new(|_cc| Box::::default())).unwrap(); -} - #[derive(Clone, Copy, PartialEq, Resource)] enum Tool { Select, @@ -103,14 +98,12 @@ const POINT_RADIUS: f32 = 3.0; fn update_hover_point( response: Res, to_screen: Res, - point_pos: PointPosQuery, - points: Query<(PointId, &Point)>, + points: Query<(PointId, &geometry::ComputedPointPos)>, mut commands: Commands, ) { - points.for_each(|(id, point)| { + points.for_each(|(id, pos)| { let hovered = if let Some(hover_pos) = response.hover_pos() { - let pos = point_pos.get(point); - let center = to_screen.transform_pos(&pos); + let center = to_screen.transform_pos(pos); (hover_pos - center).length() < (POINT_RADIUS * 3.) } else { @@ -128,13 +121,11 @@ fn update_hover_point( fn update_hover_line( response: Res, to_screen: Res, - lines: Query<(Entity, &Line)>, - line_pos: LinePosQuery, + lines: Query<(Entity, &geometry::ComputedLinePos)>, mut commands: Commands, ) { - lines.for_each(|(id, line)| { + lines.for_each(|(id, pos)| { let hovered = if let Some(hover_pos) = response.hover_pos() { - let pos = line_pos.get(line); let points = [ to_screen.transform_pos(&pos.start), to_screen.transform_pos(&pos.end), @@ -172,14 +163,32 @@ fn select_tool( }); if response.clicked() { - selected.for_each(|selected| { - commands.entity(selected).remove::(); - }); + if !response.ctx.input(|input| input.modifiers.shift) { + selected.for_each(|selected| { + commands.entity(selected).remove::(); + }); + } // TODO: choose which to select if let Some(hovered) = hovered.iter().next() { commands.entity(hovered).insert(Selected); } } + + if response + .ctx + .input(|input| input.key_pressed(egui::Key::Escape)) + { + selected.for_each(|selected| { + commands.entity(selected).remove::(); + }); + } else if response + .ctx + .input(|input| input.key_pressed(egui::Key::Delete)) + { + selected.for_each(|selected| { + commands.entity(selected).despawn(); + }); + } } #[derive(Default)] @@ -259,8 +268,7 @@ fn add_line_tool( to_screen: Res, painter: ResMut, hovered: Query, With)>, - selected: Query<(Entity, &Point), With>, - point_pos: PointPosQuery, + selected: Query<(Entity, &geometry::ComputedPointPos), With>, mut commands: Commands, ) { let hover_pos = response.hover_pos().unwrap(); @@ -276,9 +284,8 @@ fn add_line_tool( }; commands.entity(point_id).insert(Selected); } - (Some((_, start_point)), false) => { - let start_point_pos = point_pos.get(start_point); - let points = [to_screen.transform_pos(&start_point_pos), hover_pos]; + (Some((_, start_point_pos)), false) => { + let points = [to_screen.transform_pos(start_point_pos), hover_pos]; let stroke = Stroke::new(2.0, Color32::DARK_GRAY); @@ -293,7 +300,7 @@ fn add_line_tool( .next() .unwrap_or_else(|| add_point(&mut commands, hover_pos, &to_screen)); - let line = Line::new(start_point_id, end_point); + let line = LineBundle::new(start_point_id, end_point); commands.spawn(line); selected.for_each(|(selected_id, _)| { @@ -340,13 +347,11 @@ fn is_tool_active(tool: Tool) -> impl System + ReadOnlySyst fn paint_lines( to_screen: Res, painter: ResMut, - lines: Query<(Entity, &Line)>, + lines: Query<(Entity, &geometry::ComputedLinePos)>, hovered: Query<(), With>, selected: Query<(), With>, - line_pos: LinePosQuery, ) { - lines.for_each(|(id, line)| { - let pos = line_pos.get(line); + lines.for_each(|(id, pos)| { let points = [ to_screen.transform_pos(&pos.start), to_screen.transform_pos(&pos.end), @@ -367,15 +372,12 @@ fn paint_lines( fn paint_points( to_screen: Res, painter: ResMut, - points: Query<(Entity, &Point)>, + points: Query<(Entity, &geometry::ComputedPointPos)>, hovered: Query<(), With>, selected: Query<(), With>, - point_pos: PointPosQuery, ) { - points.for_each(|(id, point)| { - let pos = point_pos.get(point); - - let center = to_screen.transform_pos(&pos); + points.for_each(|(id, pos)| { + let center = to_screen.transform_pos(pos); let color = color_for_var_status(pos.status); let stroke = if selected.contains(id) || hovered.contains(id) { @@ -392,7 +394,7 @@ fn init(mut commands: Commands) { let p1 = geometry::insert_point_at(&mut commands, (10., 30.)); let p2 = geometry::insert_point_at(&mut commands, (-20., 15.)); geometry::insert_point_at(&mut commands, (0., -10.)); - commands.spawn(Line::new(p1, p2)); + commands.spawn(LineBundle::new(p1, p2)); } #[derive(Resource)] @@ -406,22 +408,6 @@ impl Deref for ContextRes { } } -#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)] -#[system_set(base)] -pub enum ShowEntitiesStage { - Input, - Tools, - Paint, -} - -#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)] -#[system_set(base)] -pub enum ScheduleSet { - Input, - Tools, - Paint, -} - fn prepare(ctx: Res) { ctx.request_repaint(); ctx.set_visuals(egui::Visuals::dark()); @@ -454,13 +440,36 @@ fn toolbar( struct ShowEntitiesSchedule(Schedule); +#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)] +#[system_set(base)] +pub enum ShowEntitiesStage { + Update, + Input, + Tools, + PostTools, + Paint, +} + impl Default for ShowEntitiesSchedule { fn default() -> Self { let mut schedule = Schedule::new(); - schedule - .configure_sets((ScheduleSet::Input, ScheduleSet::Tools, ScheduleSet::Paint).chain()); - schedule - .add_systems((update_hover_point, update_hover_line).in_base_set(ScheduleSet::Input)); + schedule.configure_sets( + ( + ShowEntitiesStage::Update, + ShowEntitiesStage::Input, + ShowEntitiesStage::Tools, + ShowEntitiesStage::PostTools, + ShowEntitiesStage::Paint, + ) + .chain(), + ); + schedule.add_systems( + (geometry::update_point_pos, geometry::update_line_pos) + .in_base_set(ShowEntitiesStage::Update), + ); + schedule.add_systems( + (update_hover_point, update_hover_line).in_base_set(ShowEntitiesStage::Input), + ); schedule.add_systems( ( select_tool.run_if(is_tool_active(Tool::Select)), @@ -470,10 +479,14 @@ impl Default for ShowEntitiesSchedule { add_relation_tool.run_if(is_tool_active(Tool::AddRelation)), ) .distributive_run_if(is_hovered) - .in_base_set(ScheduleSet::Tools), + .in_base_set(ShowEntitiesStage::Tools), ); + schedule + .add_system(geometry::remove_dangling_lines.in_base_set(ShowEntitiesStage::PostTools)); schedule.add_systems( - (paint_lines, paint_points.after(paint_lines)).in_base_set(ScheduleSet::Paint), + (paint_lines, paint_points) + .chain() + .in_base_set(ShowEntitiesStage::Paint), ); Self(schedule) } @@ -513,6 +526,53 @@ struct SelectableEntity<'a> { line: Option<&'a Line>, } +fn side_panel_ui( + ui: &mut egui::Ui, + selected: Query>, + tool: Res, +) { + let tool = *tool; + ui.vertical(|ui| match tool { + Tool::Select => { + let mut count = 0; + selected.for_each(|sel| { + count += 1; + if sel.point.is_some() { + ui.label(format!("Selected point {}", sel.id.index())); + } else if sel.line.is_some() { + ui.label(format!("Selected line {}", sel.id.index())); + } + }); + if count == 0 { + ui.label("Nothing selected"); + } + } + Tool::Move => { + let mut count = 0; + selected.for_each(|sel| { + count += 1; + if sel.point.is_some() { + ui.label(format!("Selected point {}", sel.id.index())); + } else if sel.line.is_some() { + ui.label(format!("Selected line {}", sel.id.index())); + } + }); + if count == 0 { + ui.label("Nothing selected"); + } + } + Tool::AddPoint => { + ui.label("Click to add a point"); + } + Tool::AddLine => { + ui.label("Click to add a line"); + } + Tool::AddRelation => { + ui.label("Click to add a relation"); + } + }); +} + fn side_panel( ctx: Res, selected: Query>, @@ -522,34 +582,7 @@ fn side_panel( .resizable(true) .default_width(150.0) .width_range(80.0..=200.0) - .show(&ctx, |ui| { - let tool = *tool; - ui.vertical(|ui| match tool { - Tool::Select => { - let mut count = 0; - selected.for_each(|sel| { - count += 1; - if sel.point.is_some() { - ui.label(format!("Selected point {}", sel.id.index())); - } else if sel.line.is_some() { - ui.label(format!("Selected line {}", sel.id.index())); - } - }); - if count == 0 { - ui.label("No selection"); - } - } - Tool::AddRelation => { - if let Some(_first) = selected.iter().next() { - } else { - ui.label("Select an entity to add a relation"); - } - } - _ => { - ui.label(":)"); - } - }); - }); + .show(&ctx, |ui| side_panel_ui(ui, selected, tool)); } fn bottom_panel(ctx: Res) { @@ -599,3 +632,8 @@ impl eframe::App for MyApp { self.world.remove_resource::(); } } + +fn main() { + let options = eframe::NativeOptions::default(); + eframe::run_native("sketchrs", options, Box::new(|_cc| Box::::default())).unwrap(); +} diff --git a/src/optimization.rs b/src/optimization.rs index 1decd82..55bef5a 100644 --- a/src/optimization.rs +++ b/src/optimization.rs @@ -1,122 +1,449 @@ -use std::mem::MaybeUninit; +use std::{borrow::Borrow, fmt::Debug}; -use nalgebra::{DMatrix, DVector, DimName}; +use argmin_math; +use levenberg_marquardt::{self, TerminationReason}; +use nalgebra::{ArrayStorage, Const, Dyn, VecStorage}; use nalgebra_sparse::CooMatrix; use crate::geometry::Scalar; type SVector = nalgebra::SVector; +type SMatrix = nalgebra::SMatrix; +type DVector = nalgebra::DVector; +type DMatrix = nalgebra::DMatrix; -trait SResidual { - fn apply(&self, x: SVector) -> Scalar; - fn grad(&self, x: SVector) -> SVector; +trait SResBlock: Clone + Debug { + fn evaluate( + &self, + x: SVector, + residual: Option<&mut SVector>, + jacobian: Option<&mut SMatrix>, + ) { + if let Some(residual) = residual { + *residual = self.residual(x); + } + if let Some(jacobian) = jacobian { + *jacobian = self.jacobian(x); + } + } + + fn residual(&self, x: SVector) -> SVector { + let mut residual = SVector::::zeros(); + self.evaluate(x, Some(&mut residual), None); + residual + } + fn jacobian(&self, x: SVector) -> SMatrix { + let mut jacobian = SMatrix::::zeros(); + self.evaluate(x, None, Some(&mut jacobian)); + jacobian + } + + // Probably should move to own trait + fn into_indexed( + self, + in_indices: [Index; NIN], + out_indices: [Index; NOUT], + ) -> Box + where + Self: Sized + 'static, + { + Box::new(IndexedResBlockImpl { + residual: self, + in_indices, + out_indices, + }) + } } -struct PointPointDistance(Scalar); - fn square(x: Scalar) -> Scalar { x * x } -impl SResidual<4> for PointPointDistance { - fn apply(&self, x: SVector<4>) -> Scalar { - let [[x1, y1, x2, y2]] = x.data.0; - // x.generic_slice() - square(x2 - x1) + square(y2 - y1) - square(self.0) - // (square(x2 - x1) + square(y2 - y1)).sqrt() - self.0 - } +#[derive(Clone, Debug)] +pub struct Equal1D; - fn grad(&self, x: SVector<4>) -> SVector<4> { - let [[x1, y1, x2, y2]] = x.data.0; - SVector::<4>::new( - 2. * (x1 - x2), - 2. * (y1 - y2), - 2. * (x2 - x1), - 2. * (y2 - x1), - ) +impl SResBlock<2> for Equal1D { + fn evaluate( + &self, + x: SVector<2>, + residual: Option<&mut SVector<1>>, + jacobian: Option<&mut SMatrix<2, 1>>, + ) { + let [x1, x2] = *x.as_ref(); + let dx = x2 - x1; + if let Some(residual) = residual { + residual[0] = square(dx); + } + if let Some(jacobian) = jacobian { + jacobian[0] = -2. * dx; + jacobian[1] = 2. * dx; + } } } -trait IndexedResidual { - fn apply(&self, x: &[Scalar]) -> Scalar; - fn grad_indices(&self) -> &[usize]; - fn grad_append(&self, x: &[Scalar], grad: &mut Vec); +#[derive(Clone, Debug)] +pub struct AxisDistance(Scalar); + +impl SResBlock<2> for AxisDistance { + fn evaluate( + &self, + x: SVector<2>, + residual: Option<&mut SVector<1>>, + jacobian: Option<&mut SMatrix<2, 1>>, + ) { + let [x1, x2] = *x.as_ref(); + let dx = x2 - x1; + let r = self.0; + if let Some(residual) = residual { + residual[0] = square(dx) - square(r); + } + if let Some(jacobian) = jacobian { + jacobian[0] = -2. * dx; + jacobian[1] = 2. * dx; + } + } } -struct MappedResidual { +#[derive(Clone, Debug)] +pub struct PointPointDistance(Scalar); + +impl SResBlock<4> for PointPointDistance { + fn evaluate( + &self, + x: SVector<4>, + residual: Option<&mut SVector<1>>, + jacobian: Option<&mut SMatrix<4, 1>>, + ) { + let [x1, y1, x2, y2] = *x.as_ref(); + let (dx, dy) = (x2 - x1, y2 - y1); + let r = self.0; + if let Some(residual) = residual { + residual[0] = square(dx) + square(dy) - square(r); + } + if let Some(jacobian) = jacobian { + jacobian[0] = -2. * dx; + jacobian[1] = -2. * dy; + jacobian[2] = 2. * dx; + jacobian[3] = 2. * dy; + } + } +} + +#[derive(Clone, Debug)] +pub struct PointDistance(SVector<2>, Scalar); + +impl SResBlock<2> for PointDistance { + fn evaluate( + &self, + x: SVector<2>, + residual: Option<&mut SVector<1>>, + jacobian: Option<&mut SMatrix<2, 1>>, + ) { + let d = x - self.0; + let r = self.1; + if let Some(residual) = residual { + residual[0] = d.norm_squared() - square(r); + } + if let Some(jacobian) = jacobian { + jacobian.copy_from(&(2. * d)); + } + } +} + +#[derive(Clone, Debug)] +pub struct Collinear; + +impl SResBlock<6> for Collinear { + fn evaluate( + &self, + x: SVector<6>, + residual: Option<&mut SVector<1>>, + jacobian: Option<&mut SMatrix<6, 1>>, + ) { + let [x1, y1, x2, y2, x3, y3] = *x.as_ref(); + let (dx21, dy21) = (x2 - x1, y2 - y1); + let (dx32, dy32) = (x3 - x2, y3 - y2); + if let Some(residual) = residual { + // dy21 / dx21 = dy32 / dx32 + // dy21 * dx32 = dy32 * dx21 + residual[0] = dy21 * dx32 - dy32 * dx21; + } + if let Some(jacobian) = jacobian { + jacobian[0] = dy32; + jacobian[1] = -dx32; + jacobian[2] = -dy32; + jacobian[3] = -dx21; + jacobian[4] = dy21; + jacobian[5] = -dx21; + } + } +} + +type Index = u32; + +trait ScatterResBlock: Debug { + fn shape(&self) -> (usize, usize); + fn apply_into(&self, x: &[Scalar], residuals: &mut [Scalar]); + fn jacobian_into(&self, x: &[Scalar], jacobian: &mut [Scalar]); + fn evaluate_into(&self, x: &[Scalar], residuals: &mut [Scalar], jacobian: &mut [Scalar]) { + self.apply_into(x, residuals); + self.jacobian_into(x, jacobian); + } +} + +#[derive(Clone, Debug)] +struct ScatterResBlockImpl { residual: R, - indices: [usize; N], + in_indices: [Index; NIN], + residual_indices: [Index; NOUT], + jacobian_indices: [[Index; NIN]; NOUT], } -impl MappedResidual { - fn gather_svector(&self, x: &[Scalar]) -> SVector { - SVector::from_fn(|_, i| x[self.indices[i]]) +impl ScatterResBlockImpl { + fn gather_svector(&self, x: &[Scalar]) -> SVector { + SVector::from_fn(|i, _| x[self.in_indices[i] as usize]) } } -impl> IndexedResidual for MappedResidual { - fn apply(&self, x: &[Scalar]) -> Scalar { - self.residual.apply(self.gather_svector(x)) +impl ScatterResBlock for ScatterResBlockImpl +where + R: SResBlock, +{ + fn shape(&self) -> (usize, usize) { + (NIN, NOUT) } - fn grad_indices(&self) -> &[usize] { - &self.indices + fn apply_into(&self, x: &[Scalar], residuals: &mut [Scalar]) { + // Front load bounds check + let mut inbounds = true; + for i in self.residual_indices.iter().cloned() { + inbounds &= (i as usize) < residuals.len(); + } + assert!(inbounds); + + let res = self.residual.residual(self.gather_svector(x)); + for (i, r) in self + .residual_indices + .iter() + .cloned() + .zip(res.borrow().iter()) + { + residuals[i as usize] = *r; + } } - fn grad_append(&self, x: &[Scalar], mut grad: &mut Vec) { - let gvals = self.residual.grad(self.gather_svector(x)); - grad.extend(gvals.as_slice()); + fn jacobian_into(&self, x: &[Scalar], jacobian: &mut [Scalar]) { + // Front load bounds check + let mut inbounds = true; + for ir in self.jacobian_indices.iter() { + for i in ir.iter().cloned() { + inbounds &= (i as usize) < jacobian.len(); + } + } + assert!(inbounds); + + let jac = self.residual.jacobian(self.gather_svector(x)); + for (ir, r) in self.jacobian_indices.iter().zip(jac.column_iter()) { + for (i, r) in ir.iter().zip(r.iter()) { + jacobian[*i as usize] = *r; + } + } + } + + fn evaluate_into(&self, x: &[Scalar], residuals: &mut [Scalar], jacobian: &mut [Scalar]) { + // Front load bounds check + let mut inbounds = true; + for i in self.residual_indices.iter().cloned() { + inbounds &= (i as usize) < residuals.len(); + } + for ir in self.jacobian_indices.iter() { + for i in ir.iter().cloned() { + inbounds &= (i as usize) < jacobian.len(); + } + } + assert!(inbounds); + + let mut res = SVector::::zeros(); + let mut jac = SMatrix::::zeros(); + self.residual + .evaluate(self.gather_svector(x), Some(&mut res), Some(&mut jac)); + for (i, r) in self.residual_indices.iter().cloned().zip(res.iter()) { + residuals[i as usize] = *r; + } + for (ir, r) in self.jacobian_indices.iter().zip(jac.column_iter()) { + for (i, r) in ir.iter().zip(r.iter()) { + jacobian[*i as usize] = *r; + } + } } } +trait IndexedResBlock { + fn shape(&self) -> (usize, usize); + fn in_indices(&self) -> &[Index]; + fn out_indices(&self) -> &[Index]; + fn into_scatter(self: Box, total_shape: (usize, usize)) -> Box; +} + +struct IndexedResBlockImpl { + residual: R, + in_indices: [Index; NIN], + out_indices: [Index; NOUT], +} + +impl IndexedResBlock for IndexedResBlockImpl +where + R: SResBlock + 'static, +{ + fn shape(&self) -> (usize, usize) { + (NIN, NOUT) + } + + fn in_indices(&self) -> &[Index] { + &self.in_indices + } + + fn out_indices(&self) -> &[Index] { + &self.out_indices + } + + fn into_scatter(self: Box, total_shape: (usize, usize)) -> Box { + let mut jacobian_indices = [[0; NIN]; NOUT]; + for (i, jr) in jacobian_indices.iter_mut().enumerate() { + for (j, jj) in jr.iter_mut().enumerate() { + // *jj = self.in_indices[j] + (self.out_indices[i] * total_shape.0 as u32) as Index; + *jj = self.out_indices[i] + (self.in_indices[j] * total_shape.0 as u32) as Index; + } + } + Box::new(ScatterResBlockImpl { + residual: self.residual, + in_indices: self.in_indices, + residual_indices: self.out_indices, + jacobian_indices, + }) + } +} + +#[derive(Debug)] struct Residuals { - residuals: Vec>, + res_blocks: Vec>, + shape: (usize, usize), } impl Residuals { - fn apply_fill(&self, x: &[Scalar], residuals: &mut [Scalar]) { - assert_eq!(self.residuals.len(), residuals.len()); - for (resbox, rval) in self.residuals.iter().zip(residuals) { - *rval = resbox.apply(x); + fn new(res_blocks: Vec>, shape: (usize, usize)) -> Self { + let res_blocks = res_blocks + .into_iter() + .map(|r| r.into_scatter(shape)) + .collect(); + Self { res_blocks, shape } + } + + fn residuals_into>(&self, x: X, residuals: &mut [Scalar]) { + let x = x.as_ref(); + assert!(self.shape.0 == residuals.len()); + for residual in &self.res_blocks { + residual.apply_into(x, residuals); } } - // fn grad_fill(&self, x: &[Scalar], grad: &mut [Scalar]) { - // assert_eq!(self.residuals.len(), grad.len()); - // let mut gradcol = Vec::new(); - // for resbox in self.residuals.iter() { - // let res = resbox.apply(x); - // let indices = resbox.grad_indices(); - // gradcol.resize(indices.len(), 0.); - // resbox.grad_fill(x, grad); - // for i in indices { - // grad[i] = 2 * res * gradcol - // } - // // resbox.grad_fill(x, gradcol.as_mut_slice()); + fn residuals>(&self, x: X) -> Vec { + let mut residuals = vec![0.; self.shape.0]; + self.residuals_into(x, &mut residuals); + residuals + } + + // fn jacobian_into<'a, X: AsRef<[Scalar]>, J: Into>>( + // &self, + // x: X, + // jacobian: J, + // ) { + // let x = x.as_ref(); + // let mut jacobian = jacobian.into(); + // assert!(self.shape == jacobian.shape()); + // for residual in &self.residuals { + // residual.jacobian_into(x, jacobian.as_mut_slice()); // } // } - fn jacobian(&self, x: &[Scalar]) -> CooMatrix { - let gradlen = self - .residuals - .iter() - .map(|res| res.grad_indices().len()) - .sum::(); - let mut row_indices = Vec::with_capacity(gradlen); - let mut col_indices = Vec::with_capacity(gradlen); - let mut grads = Vec::with_capacity(gradlen); - for (j, resbox) in self.residuals.iter().enumerate() { - row_indices.extend(resbox.grad_indices()); - col_indices.extend(std::iter::repeat(j).take(resbox.grad_indices().len())); - resbox.grad_append(x, &mut grads); + fn jacobian>(&self, x: X) -> DMatrix { + // assert!(self.shape == jacobian.shape()); + let x = x.as_ref(); + let mut jacobian = DMatrix::zeros(self.shape.0, self.shape.1); + for residual in &self.res_blocks { + residual.jacobian_into(x, jacobian.as_mut_slice()); } - CooMatrix::try_from_triplets( - x.len(), - self.residuals.len(), - row_indices, - col_indices, - grads, - ) - .unwrap() + jacobian } } + +#[derive(Debug)] +struct Problem { + residuals: Residuals, + x: DVector, +} + +impl levenberg_marquardt::LeastSquaresProblem for Problem { + type ResidualStorage = VecStorage>; + type JacobianStorage = VecStorage; + type ParameterStorage = VecStorage>; + + fn set_params(&mut self, x: &DVector) { + self.x = x.clone(); + } + + fn params(&self) -> DVector { + self.x.clone() + } + + fn residuals(&self) -> Option { + let res = self.residuals.residuals(self.x.as_slice()); + Some(DVector::from(res)) + } + + fn jacobian(&self) -> Option> { + let jac = self.residuals.jacobian(self.x.as_slice()); + Some(jac) + } +} + +pub fn test() { + let lm = levenberg_marquardt::LevenbergMarquardt::::new(); + + let nvar = 4; + let nres = 3; + let residuals = Residuals::new( + vec![ + PointDistance(SVector::<2>::new(1., 0.), 0.75).into_indexed([0, 1], [0]), + PointPointDistance(1.0).into_indexed([0, 1, 2, 3], [2]), + PointDistance(SVector::<2>::new(0., 0.), 1.25).into_indexed([2, 3], [1]), + ], + (nres, nvar), + ); + + let dist = rand_distr::Normal::new(0., 1.0).unwrap(); + let prob = Problem { + residuals, + x: DVector::from_distribution_generic( + Dyn(nvar), + Const::<1>, + &dist, + &mut rand::thread_rng(), + ), + }; + + println!("Problem: {:?}", prob); + + let (mut prob, rep) = lm.minimize(prob); + println!("Minimized: {:?}", prob.x); + println!("Report: {:?}", rep); + + // // perturb the initial guess + // let dist = rand_distr::Normal::new(0., 1.0).unwrap(); + // prob.x += SVector::<2>::from_distribution(&dist, &mut rand::thread_rng()); + // println!("Problem: {:?}", prob); + + // let (mut prob, rep) = lm.minimize(prob); + // println!("Minimized: {:?}", prob.x); + // println!("Report: {:?}", rep); +}