lots of good optimization shit
This commit is contained in:
parent
50ab2d1976
commit
a70db904f5
127
Cargo.lock
generated
127
Cargo.lock
generated
@ -165,6 +165,15 @@ version = "1.0.71"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
@ -224,6 +233,8 @@ dependencies = [
|
||||
"anyhow",
|
||||
"cfg-if",
|
||||
"nalgebra",
|
||||
"ndarray",
|
||||
"ndarray-linalg",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
@ -667,6 +678,27 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cauchy"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ff11ddd2af3b5e80dd0297fee6e56ac038d9bdc549573cdb51bd6d2efe7f05e"
|
||||
dependencies = [
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cblas-sys"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.79"
|
||||
@ -1565,12 +1597,45 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "katexit"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb1304c448ce2c207c2298a34bc476ce7ae47f63c23fa2b498583b26be9bc88c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "khronos_api"
|
||||
version = "3.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc"
|
||||
|
||||
[[package]]
|
||||
name = "lapack-sys"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "447f56c85fb410a7a3d36701b2153c1018b1d2b908c5fbaf01c1b04fac33bcbe"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lax"
|
||||
version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f96a229d9557112e574164f8024ce703625ad9f88a90964c1780809358e53da"
|
||||
dependencies = [
|
||||
"cauchy",
|
||||
"katexit",
|
||||
"lapack-sys",
|
||||
"num-traits",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
@ -1613,9 +1678,9 @@ checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.3.6"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b64f40e5e03e0d54f03845c8197d0291253cdbedfb1cb46b13c2c117554a9f4c"
|
||||
checksum = "ece97ea872ece730aed82664c424eb4c8291e1ff2480247ccf7409044bc6479f"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
@ -1731,12 +1796,15 @@ version = "0.32.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d68d47bba83f9e2006d117a9a33af1524e655516b8919caac694427a6fb1e511"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"approx 0.5.1",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
"serde",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
@ -1762,6 +1830,39 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.15.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
|
||||
dependencies = [
|
||||
"approx 0.4.0",
|
||||
"cblas-sys",
|
||||
"libc",
|
||||
"matrixmultiply",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rawpointer",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray-linalg"
|
||||
version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b0e8dda0c941b64a85c5deb2b3e0144aca87aced64678adfc23eacea6d2cc42"
|
||||
dependencies = [
|
||||
"cauchy",
|
||||
"katexit",
|
||||
"lax",
|
||||
"ndarray",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk"
|
||||
version = "0.7.0"
|
||||
@ -1852,6 +1953,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2218,6 +2321,16 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_xoshiro"
|
||||
version = "0.6.0"
|
||||
@ -2325,9 +2438,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.37.18"
|
||||
version = "0.37.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8bbfc1d1c7c40c01715f47d71444744a81669ca84e8b63e25a55e169b1f86433"
|
||||
checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
@ -2469,7 +2582,7 @@ version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"approx 0.5.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
@ -2496,6 +2609,8 @@ dependencies = [
|
||||
"nalgebra",
|
||||
"nalgebra-sparse",
|
||||
"nohash-hasher",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -7,14 +7,16 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
argmin = { version = "0.8.1", features = [] }
|
||||
argmin-math = { version = "0.3.0", features = ["nalgebra_0_32"] }
|
||||
argmin-math = { version = "0.3.0", features = ["latest_all"] }
|
||||
bevy_ecs = "0.10.1"
|
||||
eframe = { version = "0.21.3", features = [] }
|
||||
eframe = { version = "0.21.3" }
|
||||
indexmap = "1.8.1"
|
||||
levenberg-marquardt = "0.13.0"
|
||||
nalgebra = "0.32.0"
|
||||
nalgebra = { version = "0.32.0", features = ["rand"] }
|
||||
nalgebra-sparse = "0.9.0"
|
||||
nohash-hasher = "0.2.0"
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.4.0"
|
||||
|
123
src/geometry.rs
123
src/geometry.rs
@ -1,11 +1,11 @@
|
||||
mod var;
|
||||
|
||||
use std::ops::DerefMut;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bevy_ecs::{
|
||||
entity::Entity,
|
||||
prelude::Component,
|
||||
query::QueryEntityError,
|
||||
prelude::{Bundle, Component},
|
||||
query::{QueryEntityError, With},
|
||||
system::{Commands, Query, SystemParam},
|
||||
};
|
||||
|
||||
@ -13,7 +13,7 @@ pub use self::var::*;
|
||||
|
||||
pub type Scalar = f64;
|
||||
|
||||
#[derive(Clone, Debug, Component)]
|
||||
#[derive(Default, Clone, Debug, Component)]
|
||||
pub struct Point<V = VarId> {
|
||||
pub x: V,
|
||||
pub y: V,
|
||||
@ -34,7 +34,7 @@ impl<V> From<(V, V)> for Point<V> {
|
||||
pub type PointPos = Point<Scalar>;
|
||||
pub type PointId = Entity;
|
||||
|
||||
#[derive(Clone, Debug, Component)]
|
||||
#[derive(Clone, Debug, Default, Component)]
|
||||
// TODO: segment, ray and full line
|
||||
// TODO: what if start and end are coincident?
|
||||
pub struct Line<P = PointId> {
|
||||
@ -60,35 +60,45 @@ pub struct Circle {
|
||||
|
||||
// pub type CirclePos = CircleEntity<PointPos>;
|
||||
|
||||
#[derive(Bundle)]
|
||||
pub struct PointBundle {
|
||||
pub point: Point,
|
||||
pub computed_pos: ComputedPointPos,
|
||||
}
|
||||
|
||||
impl PointBundle {
|
||||
pub fn new(x: VarId, y: VarId) -> Self {
|
||||
Self {
|
||||
point: Point::new(x, y),
|
||||
computed_pos: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert_point_at(commands: &mut Commands, point: impl Into<PointPos>) -> PointId {
|
||||
let point = point.into();
|
||||
let x = commands.spawn(Var::new_free(point.x)).id();
|
||||
let y = commands.spawn(Var::new_free(point.y)).id();
|
||||
commands.spawn(Point::new(x, y)).id()
|
||||
commands.spawn(PointBundle::new(x, y)).id()
|
||||
}
|
||||
|
||||
// TODO: figure out generic for this
|
||||
#[derive(SystemParam)]
|
||||
pub struct PointPosQuery<'w, 's> {
|
||||
vars: Query<'w, 's, &'static Var>,
|
||||
#[derive(Component, Default)]
|
||||
pub struct ComputedPointPos(Var<PointPos>);
|
||||
|
||||
impl Deref for ComputedPointPos {
|
||||
type Target = Var<PointPos>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
// type PointPosQuery<'w, 's> = PointPosQueryImpl<'w, 's, &'static Var>;
|
||||
// type PointPosQueryMut<'w, 's> = PointPosQueryImpl<'w, 's, &'static mut Var>;
|
||||
|
||||
// impl<'w, 's, Q> PointPosQueryImpl<'w, 's, Q>
|
||||
// where
|
||||
// Q: WorldQuery + 'static,
|
||||
// for<'w2, 's2> <<Q as WorldQuery>::ReadOnlyFetch as Fetch<'w2, 's2>>::Item: Deref<Target = Var>,
|
||||
impl<'w, 's> PointPosQuery<'w, 's> {
|
||||
pub fn try_get(&self, point: &Point) -> Result<Var<PointPos>, QueryEntityError> {
|
||||
let x = self.vars.get(point.x)?;
|
||||
let y = self.vars.get(point.y)?;
|
||||
Ok(x.merge(*y, PointPos::new))
|
||||
}
|
||||
pub fn get(&self, point: &Point) -> Var<PointPos> {
|
||||
self.try_get(point).unwrap()
|
||||
}
|
||||
pub fn update_point_pos(mut points: Query<(&Point, &mut ComputedPointPos)>, vars: Query<&Var>) {
|
||||
points.for_each_mut(|(point, mut computed_pos)| {
|
||||
let x = vars.get(point.x).unwrap();
|
||||
let y = vars.get(point.y).unwrap();
|
||||
computed_pos.0 = x.merge(*y, PointPos::new);
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(SystemParam)]
|
||||
@ -116,24 +126,51 @@ impl<'w, 's> PointPosQueryMut<'w, 's>
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(SystemParam)]
|
||||
pub struct LinePosQuery<'w, 's> {
|
||||
point_pos: PointPosQuery<'w, 's>,
|
||||
points: Query<'w, 's, &'static Point>,
|
||||
}
|
||||
#[derive(Default, Component)]
|
||||
pub struct ComputedLinePos(Var<LinePos>);
|
||||
|
||||
impl<'w, 's> LinePosQuery<'w, 's> {
|
||||
pub fn get(&self, line: &Line) -> Var<LinePos> {
|
||||
self.try_get(line).unwrap()
|
||||
}
|
||||
impl Deref for ComputedLinePos {
|
||||
type Target = Var<LinePos>;
|
||||
|
||||
pub fn try_get(&self, line: &Line) -> Result<Var<LinePos>, QueryEntityError> {
|
||||
// TODO: error handling?
|
||||
let start = self.points.get(line.start)?;
|
||||
let start = self.point_pos.try_get(start)?;
|
||||
|
||||
let end = self.points.get(line.end)?;
|
||||
let end = self.point_pos.try_get(end)?;
|
||||
Ok(start.merge(end, LinePos::new))
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_line_pos(
|
||||
mut lines: Query<(&Line, &mut ComputedLinePos)>,
|
||||
points: Query<&ComputedPointPos>,
|
||||
) {
|
||||
lines.for_each_mut(|(line, mut computed_pos)| {
|
||||
let Ok(start) = points.get(line.start) else { return };
|
||||
let Ok(end) = points.get(line.end) else { return} ;
|
||||
computed_pos.0 = (*start).clone().merge((*end).clone(), LinePos::new);
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Bundle)]
|
||||
pub struct LineBundle {
|
||||
pub line: Line,
|
||||
pub computed_pos: ComputedLinePos,
|
||||
}
|
||||
|
||||
impl LineBundle {
|
||||
pub fn new(start: PointId, end: PointId) -> Self {
|
||||
Self {
|
||||
line: Line::new(start, end),
|
||||
computed_pos: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_dangling_lines(
|
||||
mut lines: Query<(Entity, &Line)>,
|
||||
points: Query<(), With<Point>>,
|
||||
mut commands: Commands,
|
||||
) {
|
||||
lines.for_each_mut(|(id, line)| {
|
||||
if !points.contains(line.start) || !points.contains(line.end) {
|
||||
commands.entity(id).despawn();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use bevy_ecs::{prelude::Component, entity::Entity};
|
||||
use bevy_ecs::{entity::Entity, prelude::Component};
|
||||
|
||||
use super::Scalar;
|
||||
|
||||
@ -28,6 +28,12 @@ pub struct Var<T = Scalar> {
|
||||
pub status: VarStatus,
|
||||
}
|
||||
|
||||
impl<T: Default> Default for Var<T> {
|
||||
fn default() -> Self {
|
||||
Self::new_free(T::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Var<T> {
|
||||
pub fn new_free(value: T) -> Self {
|
||||
Self {
|
||||
|
210
src/main.rs
210
src/main.rs
@ -10,17 +10,12 @@ use eframe::{
|
||||
emath::{Pos2, Rect, RectTransform, Vec2},
|
||||
epaint::{Color32, Hsva, Stroke},
|
||||
};
|
||||
use geometry::{Line, LinePosQuery, Point, PointId, PointPos, PointPosQuery, PointPosQueryMut};
|
||||
use geometry::{Line, LineBundle, Point, PointId, PointPos, PointPosQueryMut};
|
||||
|
||||
mod geometry;
|
||||
mod optimization;
|
||||
pub mod optimization;
|
||||
mod relations;
|
||||
|
||||
fn main() {
|
||||
let options = eframe::NativeOptions::default();
|
||||
eframe::run_native("sketchrs", options, Box::new(|_cc| Box::<MyApp>::default())).unwrap();
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Resource)]
|
||||
enum Tool {
|
||||
Select,
|
||||
@ -103,14 +98,12 @@ const POINT_RADIUS: f32 = 3.0;
|
||||
fn update_hover_point(
|
||||
response: Res<ResponseRes>,
|
||||
to_screen: Res<ToScreen>,
|
||||
point_pos: PointPosQuery,
|
||||
points: Query<(PointId, &Point)>,
|
||||
points: Query<(PointId, &geometry::ComputedPointPos)>,
|
||||
mut commands: Commands,
|
||||
) {
|
||||
points.for_each(|(id, point)| {
|
||||
points.for_each(|(id, pos)| {
|
||||
let hovered = if let Some(hover_pos) = response.hover_pos() {
|
||||
let pos = point_pos.get(point);
|
||||
let center = to_screen.transform_pos(&pos);
|
||||
let center = to_screen.transform_pos(pos);
|
||||
|
||||
(hover_pos - center).length() < (POINT_RADIUS * 3.)
|
||||
} else {
|
||||
@ -128,13 +121,11 @@ fn update_hover_point(
|
||||
fn update_hover_line(
|
||||
response: Res<ResponseRes>,
|
||||
to_screen: Res<ToScreen>,
|
||||
lines: Query<(Entity, &Line)>,
|
||||
line_pos: LinePosQuery,
|
||||
lines: Query<(Entity, &geometry::ComputedLinePos)>,
|
||||
mut commands: Commands,
|
||||
) {
|
||||
lines.for_each(|(id, line)| {
|
||||
lines.for_each(|(id, pos)| {
|
||||
let hovered = if let Some(hover_pos) = response.hover_pos() {
|
||||
let pos = line_pos.get(line);
|
||||
let points = [
|
||||
to_screen.transform_pos(&pos.start),
|
||||
to_screen.transform_pos(&pos.end),
|
||||
@ -172,14 +163,32 @@ fn select_tool(
|
||||
});
|
||||
|
||||
if response.clicked() {
|
||||
selected.for_each(|selected| {
|
||||
commands.entity(selected).remove::<Selected>();
|
||||
});
|
||||
if !response.ctx.input(|input| input.modifiers.shift) {
|
||||
selected.for_each(|selected| {
|
||||
commands.entity(selected).remove::<Selected>();
|
||||
});
|
||||
}
|
||||
// TODO: choose which to select
|
||||
if let Some(hovered) = hovered.iter().next() {
|
||||
commands.entity(hovered).insert(Selected);
|
||||
}
|
||||
}
|
||||
|
||||
if response
|
||||
.ctx
|
||||
.input(|input| input.key_pressed(egui::Key::Escape))
|
||||
{
|
||||
selected.for_each(|selected| {
|
||||
commands.entity(selected).remove::<Selected>();
|
||||
});
|
||||
} else if response
|
||||
.ctx
|
||||
.input(|input| input.key_pressed(egui::Key::Delete))
|
||||
{
|
||||
selected.for_each(|selected| {
|
||||
commands.entity(selected).despawn();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@ -259,8 +268,7 @@ fn add_line_tool(
|
||||
to_screen: Res<ToScreen>,
|
||||
painter: ResMut<PainterRes>,
|
||||
hovered: Query<Entity, (With<Hovered>, With<Point>)>,
|
||||
selected: Query<(Entity, &Point), With<Selected>>,
|
||||
point_pos: PointPosQuery,
|
||||
selected: Query<(Entity, &geometry::ComputedPointPos), With<Selected>>,
|
||||
mut commands: Commands,
|
||||
) {
|
||||
let hover_pos = response.hover_pos().unwrap();
|
||||
@ -276,9 +284,8 @@ fn add_line_tool(
|
||||
};
|
||||
commands.entity(point_id).insert(Selected);
|
||||
}
|
||||
(Some((_, start_point)), false) => {
|
||||
let start_point_pos = point_pos.get(start_point);
|
||||
let points = [to_screen.transform_pos(&start_point_pos), hover_pos];
|
||||
(Some((_, start_point_pos)), false) => {
|
||||
let points = [to_screen.transform_pos(start_point_pos), hover_pos];
|
||||
|
||||
let stroke = Stroke::new(2.0, Color32::DARK_GRAY);
|
||||
|
||||
@ -293,7 +300,7 @@ fn add_line_tool(
|
||||
.next()
|
||||
.unwrap_or_else(|| add_point(&mut commands, hover_pos, &to_screen));
|
||||
|
||||
let line = Line::new(start_point_id, end_point);
|
||||
let line = LineBundle::new(start_point_id, end_point);
|
||||
commands.spawn(line);
|
||||
|
||||
selected.for_each(|(selected_id, _)| {
|
||||
@ -340,13 +347,11 @@ fn is_tool_active(tool: Tool) -> impl System<In = (), Out = bool> + ReadOnlySyst
|
||||
fn paint_lines(
|
||||
to_screen: Res<ToScreen>,
|
||||
painter: ResMut<PainterRes>,
|
||||
lines: Query<(Entity, &Line)>,
|
||||
lines: Query<(Entity, &geometry::ComputedLinePos)>,
|
||||
hovered: Query<(), With<Hovered>>,
|
||||
selected: Query<(), With<Selected>>,
|
||||
line_pos: LinePosQuery,
|
||||
) {
|
||||
lines.for_each(|(id, line)| {
|
||||
let pos = line_pos.get(line);
|
||||
lines.for_each(|(id, pos)| {
|
||||
let points = [
|
||||
to_screen.transform_pos(&pos.start),
|
||||
to_screen.transform_pos(&pos.end),
|
||||
@ -367,15 +372,12 @@ fn paint_lines(
|
||||
fn paint_points(
|
||||
to_screen: Res<ToScreen>,
|
||||
painter: ResMut<PainterRes>,
|
||||
points: Query<(Entity, &Point)>,
|
||||
points: Query<(Entity, &geometry::ComputedPointPos)>,
|
||||
hovered: Query<(), With<Hovered>>,
|
||||
selected: Query<(), With<Selected>>,
|
||||
point_pos: PointPosQuery,
|
||||
) {
|
||||
points.for_each(|(id, point)| {
|
||||
let pos = point_pos.get(point);
|
||||
|
||||
let center = to_screen.transform_pos(&pos);
|
||||
points.for_each(|(id, pos)| {
|
||||
let center = to_screen.transform_pos(pos);
|
||||
|
||||
let color = color_for_var_status(pos.status);
|
||||
let stroke = if selected.contains(id) || hovered.contains(id) {
|
||||
@ -392,7 +394,7 @@ fn init(mut commands: Commands) {
|
||||
let p1 = geometry::insert_point_at(&mut commands, (10., 30.));
|
||||
let p2 = geometry::insert_point_at(&mut commands, (-20., 15.));
|
||||
geometry::insert_point_at(&mut commands, (0., -10.));
|
||||
commands.spawn(Line::new(p1, p2));
|
||||
commands.spawn(LineBundle::new(p1, p2));
|
||||
}
|
||||
|
||||
#[derive(Resource)]
|
||||
@ -406,22 +408,6 @@ impl Deref for ContextRes {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[system_set(base)]
|
||||
pub enum ShowEntitiesStage {
|
||||
Input,
|
||||
Tools,
|
||||
Paint,
|
||||
}
|
||||
|
||||
#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[system_set(base)]
|
||||
pub enum ScheduleSet {
|
||||
Input,
|
||||
Tools,
|
||||
Paint,
|
||||
}
|
||||
|
||||
fn prepare(ctx: Res<ContextRes>) {
|
||||
ctx.request_repaint();
|
||||
ctx.set_visuals(egui::Visuals::dark());
|
||||
@ -454,13 +440,36 @@ fn toolbar(
|
||||
|
||||
struct ShowEntitiesSchedule(Schedule);
|
||||
|
||||
#[derive(SystemSet, Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[system_set(base)]
|
||||
pub enum ShowEntitiesStage {
|
||||
Update,
|
||||
Input,
|
||||
Tools,
|
||||
PostTools,
|
||||
Paint,
|
||||
}
|
||||
|
||||
impl Default for ShowEntitiesSchedule {
|
||||
fn default() -> Self {
|
||||
let mut schedule = Schedule::new();
|
||||
schedule
|
||||
.configure_sets((ScheduleSet::Input, ScheduleSet::Tools, ScheduleSet::Paint).chain());
|
||||
schedule
|
||||
.add_systems((update_hover_point, update_hover_line).in_base_set(ScheduleSet::Input));
|
||||
schedule.configure_sets(
|
||||
(
|
||||
ShowEntitiesStage::Update,
|
||||
ShowEntitiesStage::Input,
|
||||
ShowEntitiesStage::Tools,
|
||||
ShowEntitiesStage::PostTools,
|
||||
ShowEntitiesStage::Paint,
|
||||
)
|
||||
.chain(),
|
||||
);
|
||||
schedule.add_systems(
|
||||
(geometry::update_point_pos, geometry::update_line_pos)
|
||||
.in_base_set(ShowEntitiesStage::Update),
|
||||
);
|
||||
schedule.add_systems(
|
||||
(update_hover_point, update_hover_line).in_base_set(ShowEntitiesStage::Input),
|
||||
);
|
||||
schedule.add_systems(
|
||||
(
|
||||
select_tool.run_if(is_tool_active(Tool::Select)),
|
||||
@ -470,10 +479,14 @@ impl Default for ShowEntitiesSchedule {
|
||||
add_relation_tool.run_if(is_tool_active(Tool::AddRelation)),
|
||||
)
|
||||
.distributive_run_if(is_hovered)
|
||||
.in_base_set(ScheduleSet::Tools),
|
||||
.in_base_set(ShowEntitiesStage::Tools),
|
||||
);
|
||||
schedule
|
||||
.add_system(geometry::remove_dangling_lines.in_base_set(ShowEntitiesStage::PostTools));
|
||||
schedule.add_systems(
|
||||
(paint_lines, paint_points.after(paint_lines)).in_base_set(ScheduleSet::Paint),
|
||||
(paint_lines, paint_points)
|
||||
.chain()
|
||||
.in_base_set(ShowEntitiesStage::Paint),
|
||||
);
|
||||
Self(schedule)
|
||||
}
|
||||
@ -513,6 +526,53 @@ struct SelectableEntity<'a> {
|
||||
line: Option<&'a Line>,
|
||||
}
|
||||
|
||||
fn side_panel_ui(
|
||||
ui: &mut egui::Ui,
|
||||
selected: Query<SelectableEntity, With<Selected>>,
|
||||
tool: Res<Tool>,
|
||||
) {
|
||||
let tool = *tool;
|
||||
ui.vertical(|ui| match tool {
|
||||
Tool::Select => {
|
||||
let mut count = 0;
|
||||
selected.for_each(|sel| {
|
||||
count += 1;
|
||||
if sel.point.is_some() {
|
||||
ui.label(format!("Selected point {}", sel.id.index()));
|
||||
} else if sel.line.is_some() {
|
||||
ui.label(format!("Selected line {}", sel.id.index()));
|
||||
}
|
||||
});
|
||||
if count == 0 {
|
||||
ui.label("Nothing selected");
|
||||
}
|
||||
}
|
||||
Tool::Move => {
|
||||
let mut count = 0;
|
||||
selected.for_each(|sel| {
|
||||
count += 1;
|
||||
if sel.point.is_some() {
|
||||
ui.label(format!("Selected point {}", sel.id.index()));
|
||||
} else if sel.line.is_some() {
|
||||
ui.label(format!("Selected line {}", sel.id.index()));
|
||||
}
|
||||
});
|
||||
if count == 0 {
|
||||
ui.label("Nothing selected");
|
||||
}
|
||||
}
|
||||
Tool::AddPoint => {
|
||||
ui.label("Click to add a point");
|
||||
}
|
||||
Tool::AddLine => {
|
||||
ui.label("Click to add a line");
|
||||
}
|
||||
Tool::AddRelation => {
|
||||
ui.label("Click to add a relation");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn side_panel(
|
||||
ctx: Res<ContextRes>,
|
||||
selected: Query<SelectableEntity, With<Selected>>,
|
||||
@ -522,34 +582,7 @@ fn side_panel(
|
||||
.resizable(true)
|
||||
.default_width(150.0)
|
||||
.width_range(80.0..=200.0)
|
||||
.show(&ctx, |ui| {
|
||||
let tool = *tool;
|
||||
ui.vertical(|ui| match tool {
|
||||
Tool::Select => {
|
||||
let mut count = 0;
|
||||
selected.for_each(|sel| {
|
||||
count += 1;
|
||||
if sel.point.is_some() {
|
||||
ui.label(format!("Selected point {}", sel.id.index()));
|
||||
} else if sel.line.is_some() {
|
||||
ui.label(format!("Selected line {}", sel.id.index()));
|
||||
}
|
||||
});
|
||||
if count == 0 {
|
||||
ui.label("No selection");
|
||||
}
|
||||
}
|
||||
Tool::AddRelation => {
|
||||
if let Some(_first) = selected.iter().next() {
|
||||
} else {
|
||||
ui.label("Select an entity to add a relation");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
ui.label(":)");
|
||||
}
|
||||
});
|
||||
});
|
||||
.show(&ctx, |ui| side_panel_ui(ui, selected, tool));
|
||||
}
|
||||
|
||||
fn bottom_panel(ctx: Res<ContextRes>) {
|
||||
@ -599,3 +632,8 @@ impl eframe::App for MyApp {
|
||||
self.world.remove_resource::<ContextRes>();
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let options = eframe::NativeOptions::default();
|
||||
eframe::run_native("sketchrs", options, Box::new(|_cc| Box::<MyApp>::default())).unwrap();
|
||||
}
|
||||
|
@ -1,122 +1,449 @@
|
||||
use std::mem::MaybeUninit;
|
||||
use std::{borrow::Borrow, fmt::Debug};
|
||||
|
||||
use nalgebra::{DMatrix, DVector, DimName};
|
||||
use argmin_math;
|
||||
use levenberg_marquardt::{self, TerminationReason};
|
||||
use nalgebra::{ArrayStorage, Const, Dyn, VecStorage};
|
||||
use nalgebra_sparse::CooMatrix;
|
||||
|
||||
use crate::geometry::Scalar;
|
||||
|
||||
type SVector<const D: usize> = nalgebra::SVector<Scalar, D>;
|
||||
type SMatrix<const R: usize, const C: usize> = nalgebra::SMatrix<Scalar, R, C>;
|
||||
type DVector = nalgebra::DVector<Scalar>;
|
||||
type DMatrix = nalgebra::DMatrix<Scalar>;
|
||||
|
||||
trait SResidual<const N: usize> {
|
||||
fn apply(&self, x: SVector<N>) -> Scalar;
|
||||
fn grad(&self, x: SVector<N>) -> SVector<N>;
|
||||
trait SResBlock<const NIN: usize, const NOUT: usize = 1>: Clone + Debug {
|
||||
fn evaluate(
|
||||
&self,
|
||||
x: SVector<NIN>,
|
||||
residual: Option<&mut SVector<NOUT>>,
|
||||
jacobian: Option<&mut SMatrix<NIN, NOUT>>,
|
||||
) {
|
||||
if let Some(residual) = residual {
|
||||
*residual = self.residual(x);
|
||||
}
|
||||
if let Some(jacobian) = jacobian {
|
||||
*jacobian = self.jacobian(x);
|
||||
}
|
||||
}
|
||||
|
||||
fn residual(&self, x: SVector<NIN>) -> SVector<NOUT> {
|
||||
let mut residual = SVector::<NOUT>::zeros();
|
||||
self.evaluate(x, Some(&mut residual), None);
|
||||
residual
|
||||
}
|
||||
fn jacobian(&self, x: SVector<NIN>) -> SMatrix<NIN, NOUT> {
|
||||
let mut jacobian = SMatrix::<NIN, NOUT>::zeros();
|
||||
self.evaluate(x, None, Some(&mut jacobian));
|
||||
jacobian
|
||||
}
|
||||
|
||||
// Probably should move to own trait
|
||||
fn into_indexed(
|
||||
self,
|
||||
in_indices: [Index; NIN],
|
||||
out_indices: [Index; NOUT],
|
||||
) -> Box<dyn IndexedResBlock>
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
{
|
||||
Box::new(IndexedResBlockImpl {
|
||||
residual: self,
|
||||
in_indices,
|
||||
out_indices,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct PointPointDistance(Scalar);
|
||||
|
||||
fn square(x: Scalar) -> Scalar {
|
||||
x * x
|
||||
}
|
||||
|
||||
impl SResidual<4> for PointPointDistance {
|
||||
fn apply(&self, x: SVector<4>) -> Scalar {
|
||||
let [[x1, y1, x2, y2]] = x.data.0;
|
||||
// x.generic_slice()
|
||||
square(x2 - x1) + square(y2 - y1) - square(self.0)
|
||||
// (square(x2 - x1) + square(y2 - y1)).sqrt() - self.0
|
||||
}
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Equal1D;
|
||||
|
||||
fn grad(&self, x: SVector<4>) -> SVector<4> {
|
||||
let [[x1, y1, x2, y2]] = x.data.0;
|
||||
SVector::<4>::new(
|
||||
2. * (x1 - x2),
|
||||
2. * (y1 - y2),
|
||||
2. * (x2 - x1),
|
||||
2. * (y2 - x1),
|
||||
)
|
||||
impl SResBlock<2> for Equal1D {
|
||||
fn evaluate(
|
||||
&self,
|
||||
x: SVector<2>,
|
||||
residual: Option<&mut SVector<1>>,
|
||||
jacobian: Option<&mut SMatrix<2, 1>>,
|
||||
) {
|
||||
let [x1, x2] = *x.as_ref();
|
||||
let dx = x2 - x1;
|
||||
if let Some(residual) = residual {
|
||||
residual[0] = square(dx);
|
||||
}
|
||||
if let Some(jacobian) = jacobian {
|
||||
jacobian[0] = -2. * dx;
|
||||
jacobian[1] = 2. * dx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait IndexedResidual {
|
||||
fn apply(&self, x: &[Scalar]) -> Scalar;
|
||||
fn grad_indices(&self) -> &[usize];
|
||||
fn grad_append(&self, x: &[Scalar], grad: &mut Vec<Scalar>);
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AxisDistance(Scalar);
|
||||
|
||||
impl SResBlock<2> for AxisDistance {
|
||||
fn evaluate(
|
||||
&self,
|
||||
x: SVector<2>,
|
||||
residual: Option<&mut SVector<1>>,
|
||||
jacobian: Option<&mut SMatrix<2, 1>>,
|
||||
) {
|
||||
let [x1, x2] = *x.as_ref();
|
||||
let dx = x2 - x1;
|
||||
let r = self.0;
|
||||
if let Some(residual) = residual {
|
||||
residual[0] = square(dx) - square(r);
|
||||
}
|
||||
if let Some(jacobian) = jacobian {
|
||||
jacobian[0] = -2. * dx;
|
||||
jacobian[1] = 2. * dx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MappedResidual<const N: usize, R> {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PointPointDistance(Scalar);
|
||||
|
||||
impl SResBlock<4> for PointPointDistance {
|
||||
fn evaluate(
|
||||
&self,
|
||||
x: SVector<4>,
|
||||
residual: Option<&mut SVector<1>>,
|
||||
jacobian: Option<&mut SMatrix<4, 1>>,
|
||||
) {
|
||||
let [x1, y1, x2, y2] = *x.as_ref();
|
||||
let (dx, dy) = (x2 - x1, y2 - y1);
|
||||
let r = self.0;
|
||||
if let Some(residual) = residual {
|
||||
residual[0] = square(dx) + square(dy) - square(r);
|
||||
}
|
||||
if let Some(jacobian) = jacobian {
|
||||
jacobian[0] = -2. * dx;
|
||||
jacobian[1] = -2. * dy;
|
||||
jacobian[2] = 2. * dx;
|
||||
jacobian[3] = 2. * dy;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PointDistance(SVector<2>, Scalar);
|
||||
|
||||
impl SResBlock<2> for PointDistance {
|
||||
fn evaluate(
|
||||
&self,
|
||||
x: SVector<2>,
|
||||
residual: Option<&mut SVector<1>>,
|
||||
jacobian: Option<&mut SMatrix<2, 1>>,
|
||||
) {
|
||||
let d = x - self.0;
|
||||
let r = self.1;
|
||||
if let Some(residual) = residual {
|
||||
residual[0] = d.norm_squared() - square(r);
|
||||
}
|
||||
if let Some(jacobian) = jacobian {
|
||||
jacobian.copy_from(&(2. * d));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Collinear;
|
||||
|
||||
impl SResBlock<6> for Collinear {
|
||||
fn evaluate(
|
||||
&self,
|
||||
x: SVector<6>,
|
||||
residual: Option<&mut SVector<1>>,
|
||||
jacobian: Option<&mut SMatrix<6, 1>>,
|
||||
) {
|
||||
let [x1, y1, x2, y2, x3, y3] = *x.as_ref();
|
||||
let (dx21, dy21) = (x2 - x1, y2 - y1);
|
||||
let (dx32, dy32) = (x3 - x2, y3 - y2);
|
||||
if let Some(residual) = residual {
|
||||
// dy21 / dx21 = dy32 / dx32
|
||||
// dy21 * dx32 = dy32 * dx21
|
||||
residual[0] = dy21 * dx32 - dy32 * dx21;
|
||||
}
|
||||
if let Some(jacobian) = jacobian {
|
||||
jacobian[0] = dy32;
|
||||
jacobian[1] = -dx32;
|
||||
jacobian[2] = -dy32;
|
||||
jacobian[3] = -dx21;
|
||||
jacobian[4] = dy21;
|
||||
jacobian[5] = -dx21;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Index = u32;
|
||||
|
||||
trait ScatterResBlock: Debug {
|
||||
fn shape(&self) -> (usize, usize);
|
||||
fn apply_into(&self, x: &[Scalar], residuals: &mut [Scalar]);
|
||||
fn jacobian_into(&self, x: &[Scalar], jacobian: &mut [Scalar]);
|
||||
fn evaluate_into(&self, x: &[Scalar], residuals: &mut [Scalar], jacobian: &mut [Scalar]) {
|
||||
self.apply_into(x, residuals);
|
||||
self.jacobian_into(x, jacobian);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ScatterResBlockImpl<const NIN: usize, const NOUT: usize, R> {
|
||||
residual: R,
|
||||
indices: [usize; N],
|
||||
in_indices: [Index; NIN],
|
||||
residual_indices: [Index; NOUT],
|
||||
jacobian_indices: [[Index; NIN]; NOUT],
|
||||
}
|
||||
|
||||
impl<const N: usize, R> MappedResidual<N, R> {
|
||||
fn gather_svector(&self, x: &[Scalar]) -> SVector<N> {
|
||||
SVector::from_fn(|_, i| x[self.indices[i]])
|
||||
impl<const NIN: usize, const NOUT: usize, R> ScatterResBlockImpl<NIN, NOUT, R> {
|
||||
fn gather_svector(&self, x: &[Scalar]) -> SVector<NIN> {
|
||||
SVector::from_fn(|i, _| x[self.in_indices[i] as usize])
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, R: SResidual<N>> IndexedResidual for MappedResidual<N, R> {
|
||||
fn apply(&self, x: &[Scalar]) -> Scalar {
|
||||
self.residual.apply(self.gather_svector(x))
|
||||
impl<const NIN: usize, const NOUT: usize, R> ScatterResBlock for ScatterResBlockImpl<NIN, NOUT, R>
|
||||
where
|
||||
R: SResBlock<NIN, NOUT>,
|
||||
{
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(NIN, NOUT)
|
||||
}
|
||||
|
||||
fn grad_indices(&self) -> &[usize] {
|
||||
&self.indices
|
||||
fn apply_into(&self, x: &[Scalar], residuals: &mut [Scalar]) {
|
||||
// Front load bounds check
|
||||
let mut inbounds = true;
|
||||
for i in self.residual_indices.iter().cloned() {
|
||||
inbounds &= (i as usize) < residuals.len();
|
||||
}
|
||||
assert!(inbounds);
|
||||
|
||||
let res = self.residual.residual(self.gather_svector(x));
|
||||
for (i, r) in self
|
||||
.residual_indices
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(res.borrow().iter())
|
||||
{
|
||||
residuals[i as usize] = *r;
|
||||
}
|
||||
}
|
||||
|
||||
fn grad_append(&self, x: &[Scalar], mut grad: &mut Vec<Scalar>) {
|
||||
let gvals = self.residual.grad(self.gather_svector(x));
|
||||
grad.extend(gvals.as_slice());
|
||||
fn jacobian_into(&self, x: &[Scalar], jacobian: &mut [Scalar]) {
|
||||
// Front load bounds check
|
||||
let mut inbounds = true;
|
||||
for ir in self.jacobian_indices.iter() {
|
||||
for i in ir.iter().cloned() {
|
||||
inbounds &= (i as usize) < jacobian.len();
|
||||
}
|
||||
}
|
||||
assert!(inbounds);
|
||||
|
||||
let jac = self.residual.jacobian(self.gather_svector(x));
|
||||
for (ir, r) in self.jacobian_indices.iter().zip(jac.column_iter()) {
|
||||
for (i, r) in ir.iter().zip(r.iter()) {
|
||||
jacobian[*i as usize] = *r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_into(&self, x: &[Scalar], residuals: &mut [Scalar], jacobian: &mut [Scalar]) {
|
||||
// Front load bounds check
|
||||
let mut inbounds = true;
|
||||
for i in self.residual_indices.iter().cloned() {
|
||||
inbounds &= (i as usize) < residuals.len();
|
||||
}
|
||||
for ir in self.jacobian_indices.iter() {
|
||||
for i in ir.iter().cloned() {
|
||||
inbounds &= (i as usize) < jacobian.len();
|
||||
}
|
||||
}
|
||||
assert!(inbounds);
|
||||
|
||||
let mut res = SVector::<NOUT>::zeros();
|
||||
let mut jac = SMatrix::<NIN, NOUT>::zeros();
|
||||
self.residual
|
||||
.evaluate(self.gather_svector(x), Some(&mut res), Some(&mut jac));
|
||||
for (i, r) in self.residual_indices.iter().cloned().zip(res.iter()) {
|
||||
residuals[i as usize] = *r;
|
||||
}
|
||||
for (ir, r) in self.jacobian_indices.iter().zip(jac.column_iter()) {
|
||||
for (i, r) in ir.iter().zip(r.iter()) {
|
||||
jacobian[*i as usize] = *r;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait IndexedResBlock {
|
||||
fn shape(&self) -> (usize, usize);
|
||||
fn in_indices(&self) -> &[Index];
|
||||
fn out_indices(&self) -> &[Index];
|
||||
fn into_scatter(self: Box<Self>, total_shape: (usize, usize)) -> Box<dyn ScatterResBlock>;
|
||||
}
|
||||
|
||||
struct IndexedResBlockImpl<const NIN: usize, const NOUT: usize, R> {
|
||||
residual: R,
|
||||
in_indices: [Index; NIN],
|
||||
out_indices: [Index; NOUT],
|
||||
}
|
||||
|
||||
impl<const NIN: usize, const NOUT: usize, R> IndexedResBlock for IndexedResBlockImpl<NIN, NOUT, R>
|
||||
where
|
||||
R: SResBlock<NIN, NOUT> + 'static,
|
||||
{
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(NIN, NOUT)
|
||||
}
|
||||
|
||||
fn in_indices(&self) -> &[Index] {
|
||||
&self.in_indices
|
||||
}
|
||||
|
||||
fn out_indices(&self) -> &[Index] {
|
||||
&self.out_indices
|
||||
}
|
||||
|
||||
fn into_scatter(self: Box<Self>, total_shape: (usize, usize)) -> Box<dyn ScatterResBlock> {
|
||||
let mut jacobian_indices = [[0; NIN]; NOUT];
|
||||
for (i, jr) in jacobian_indices.iter_mut().enumerate() {
|
||||
for (j, jj) in jr.iter_mut().enumerate() {
|
||||
// *jj = self.in_indices[j] + (self.out_indices[i] * total_shape.0 as u32) as Index;
|
||||
*jj = self.out_indices[i] + (self.in_indices[j] * total_shape.0 as u32) as Index;
|
||||
}
|
||||
}
|
||||
Box::new(ScatterResBlockImpl {
|
||||
residual: self.residual,
|
||||
in_indices: self.in_indices,
|
||||
residual_indices: self.out_indices,
|
||||
jacobian_indices,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Residuals {
|
||||
residuals: Vec<Box<dyn IndexedResidual>>,
|
||||
res_blocks: Vec<Box<dyn ScatterResBlock>>,
|
||||
shape: (usize, usize),
|
||||
}
|
||||
|
||||
impl Residuals {
|
||||
fn apply_fill(&self, x: &[Scalar], residuals: &mut [Scalar]) {
|
||||
assert_eq!(self.residuals.len(), residuals.len());
|
||||
for (resbox, rval) in self.residuals.iter().zip(residuals) {
|
||||
*rval = resbox.apply(x);
|
||||
fn new(res_blocks: Vec<Box<dyn IndexedResBlock>>, shape: (usize, usize)) -> Self {
|
||||
let res_blocks = res_blocks
|
||||
.into_iter()
|
||||
.map(|r| r.into_scatter(shape))
|
||||
.collect();
|
||||
Self { res_blocks, shape }
|
||||
}
|
||||
|
||||
fn residuals_into<X: AsRef<[Scalar]>>(&self, x: X, residuals: &mut [Scalar]) {
|
||||
let x = x.as_ref();
|
||||
assert!(self.shape.0 == residuals.len());
|
||||
for residual in &self.res_blocks {
|
||||
residual.apply_into(x, residuals);
|
||||
}
|
||||
}
|
||||
|
||||
// fn grad_fill(&self, x: &[Scalar], grad: &mut [Scalar]) {
|
||||
// assert_eq!(self.residuals.len(), grad.len());
|
||||
// let mut gradcol = Vec::new();
|
||||
// for resbox in self.residuals.iter() {
|
||||
// let res = resbox.apply(x);
|
||||
// let indices = resbox.grad_indices();
|
||||
// gradcol.resize(indices.len(), 0.);
|
||||
// resbox.grad_fill(x, grad);
|
||||
// for i in indices {
|
||||
// grad[i] = 2 * res * gradcol
|
||||
// }
|
||||
// // resbox.grad_fill(x, gradcol.as_mut_slice());
|
||||
fn residuals<X: AsRef<[Scalar]>>(&self, x: X) -> Vec<Scalar> {
|
||||
let mut residuals = vec![0.; self.shape.0];
|
||||
self.residuals_into(x, &mut residuals);
|
||||
residuals
|
||||
}
|
||||
|
||||
// fn jacobian_into<'a, X: AsRef<[Scalar]>, J: Into<DMatrixViewMut<'a>>>(
|
||||
// &self,
|
||||
// x: X,
|
||||
// jacobian: J,
|
||||
// ) {
|
||||
// let x = x.as_ref();
|
||||
// let mut jacobian = jacobian.into();
|
||||
// assert!(self.shape == jacobian.shape());
|
||||
// for residual in &self.residuals {
|
||||
// residual.jacobian_into(x, jacobian.as_mut_slice());
|
||||
// }
|
||||
// }
|
||||
|
||||
fn jacobian(&self, x: &[Scalar]) -> CooMatrix<Scalar> {
|
||||
let gradlen = self
|
||||
.residuals
|
||||
.iter()
|
||||
.map(|res| res.grad_indices().len())
|
||||
.sum::<usize>();
|
||||
let mut row_indices = Vec::with_capacity(gradlen);
|
||||
let mut col_indices = Vec::with_capacity(gradlen);
|
||||
let mut grads = Vec::with_capacity(gradlen);
|
||||
for (j, resbox) in self.residuals.iter().enumerate() {
|
||||
row_indices.extend(resbox.grad_indices());
|
||||
col_indices.extend(std::iter::repeat(j).take(resbox.grad_indices().len()));
|
||||
resbox.grad_append(x, &mut grads);
|
||||
fn jacobian<X: AsRef<[Scalar]>>(&self, x: X) -> DMatrix {
|
||||
// assert!(self.shape == jacobian.shape());
|
||||
let x = x.as_ref();
|
||||
let mut jacobian = DMatrix::zeros(self.shape.0, self.shape.1);
|
||||
for residual in &self.res_blocks {
|
||||
residual.jacobian_into(x, jacobian.as_mut_slice());
|
||||
}
|
||||
CooMatrix::try_from_triplets(
|
||||
x.len(),
|
||||
self.residuals.len(),
|
||||
row_indices,
|
||||
col_indices,
|
||||
grads,
|
||||
)
|
||||
.unwrap()
|
||||
jacobian
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Problem {
|
||||
residuals: Residuals,
|
||||
x: DVector,
|
||||
}
|
||||
|
||||
impl levenberg_marquardt::LeastSquaresProblem<Scalar, Dyn, Dyn> for Problem {
|
||||
type ResidualStorage = VecStorage<Scalar, Dyn, Const<1>>;
|
||||
type JacobianStorage = VecStorage<Scalar, Dyn, Dyn>;
|
||||
type ParameterStorage = VecStorage<Scalar, Dyn, Const<1>>;
|
||||
|
||||
fn set_params(&mut self, x: &DVector) {
|
||||
self.x = x.clone();
|
||||
}
|
||||
|
||||
fn params(&self) -> DVector {
|
||||
self.x.clone()
|
||||
}
|
||||
|
||||
fn residuals(&self) -> Option<DVector> {
|
||||
let res = self.residuals.residuals(self.x.as_slice());
|
||||
Some(DVector::from(res))
|
||||
}
|
||||
|
||||
fn jacobian(&self) -> Option<nalgebra::Matrix<Scalar, Dyn, Dyn, Self::JacobianStorage>> {
|
||||
let jac = self.residuals.jacobian(self.x.as_slice());
|
||||
Some(jac)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test() {
|
||||
let lm = levenberg_marquardt::LevenbergMarquardt::<f64>::new();
|
||||
|
||||
let nvar = 4;
|
||||
let nres = 3;
|
||||
let residuals = Residuals::new(
|
||||
vec![
|
||||
PointDistance(SVector::<2>::new(1., 0.), 0.75).into_indexed([0, 1], [0]),
|
||||
PointPointDistance(1.0).into_indexed([0, 1, 2, 3], [2]),
|
||||
PointDistance(SVector::<2>::new(0., 0.), 1.25).into_indexed([2, 3], [1]),
|
||||
],
|
||||
(nres, nvar),
|
||||
);
|
||||
|
||||
let dist = rand_distr::Normal::new(0., 1.0).unwrap();
|
||||
let prob = Problem {
|
||||
residuals,
|
||||
x: DVector::from_distribution_generic(
|
||||
Dyn(nvar),
|
||||
Const::<1>,
|
||||
&dist,
|
||||
&mut rand::thread_rng(),
|
||||
),
|
||||
};
|
||||
|
||||
println!("Problem: {:?}", prob);
|
||||
|
||||
let (mut prob, rep) = lm.minimize(prob);
|
||||
println!("Minimized: {:?}", prob.x);
|
||||
println!("Report: {:?}", rep);
|
||||
|
||||
// // perturb the initial guess
|
||||
// let dist = rand_distr::Normal::new(0., 1.0).unwrap();
|
||||
// prob.x += SVector::<2>::from_distribution(&dist, &mut rand::thread_rng());
|
||||
// println!("Problem: {:?}", prob);
|
||||
|
||||
// let (mut prob, rep) = lm.minimize(prob);
|
||||
// println!("Minimized: {:?}", prob.x);
|
||||
// println!("Report: {:?}", rep);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user