Initial commit
This commit is contained in:
commit
73bc64a625
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/target
|
2433
Cargo.lock
generated
Normal file
2433
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
19
Cargo.toml
Normal file
19
Cargo.toml
Normal file
@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "sketchrs"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
argmin = { version = "0.5.1", features = ["nalgebra"] }
|
||||
eframe = { version = "0.18.0", features = [] }
|
||||
indexmap = "1.8.1"
|
||||
levenberg-marquardt = "0.12.0"
|
||||
nalgebra = "0.30.1"
|
||||
nalgebra-sparse = "0.6.0"
|
||||
nohash-hasher = "0.2.0"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.5"
|
||||
|
131
src/geometry.rs
Normal file
131
src/geometry.rs
Normal file
@ -0,0 +1,131 @@
|
||||
mod entity;
|
||||
mod var;
|
||||
|
||||
use self::entity::{EntityId, EntityMap};
|
||||
pub use self::var::*;
|
||||
|
||||
pub type Scalar = f64;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PointEntity<V = VarId> {
|
||||
pub x: V,
|
||||
pub y: V,
|
||||
}
|
||||
|
||||
impl<V> PointEntity<V> {
|
||||
pub fn new(x: V, y: V) -> Self {
|
||||
Self { x, y }
|
||||
}
|
||||
}
|
||||
|
||||
impl<V> From<(V, V)> for PointEntity<V> {
|
||||
fn from(p: (V, V)) -> Self {
|
||||
Self { x: p.0, y: p.1 }
|
||||
}
|
||||
}
|
||||
|
||||
pub type PointPos = PointEntity<Scalar>;
|
||||
pub type PointId = EntityId<PointEntity>;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
// TODO: segment, ray and full line
|
||||
// TODO: what if start and end are coincident?
|
||||
pub struct LineEntity<P = PointId> {
|
||||
pub start: P,
|
||||
pub end: P,
|
||||
}
|
||||
|
||||
impl<P> LineEntity<P> {
|
||||
pub fn new(start: P, end: P) -> Self {
|
||||
Self { start, end }
|
||||
}
|
||||
}
|
||||
|
||||
pub type LinePos = LineEntity<PointPos>;
|
||||
pub type LineId = EntityId<LineEntity>;
|
||||
|
||||
#[derive(Clone)]
|
||||
// TODO: arc. how to represent start and end of arc?
|
||||
// angles, unit vectors, or a point for each?
|
||||
pub struct CircleEntity {
|
||||
pub center: PointId,
|
||||
pub radius: VarId,
|
||||
}
|
||||
|
||||
// pub type CirclePos = CircleEntity<PointPos>;
|
||||
pub type CircleId = EntityId<CircleEntity>;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct SketchEntities {
|
||||
vars: EntityMap<Var>,
|
||||
points: EntityMap<PointEntity>,
|
||||
lines: EntityMap<LineEntity>,
|
||||
circles: EntityMap<CircleEntity>,
|
||||
}
|
||||
|
||||
impl SketchEntities {
|
||||
pub fn vars(&self) -> &EntityMap<Var> {
|
||||
&self.vars
|
||||
}
|
||||
|
||||
pub fn insert_var(&mut self, var: Var) -> VarId {
|
||||
self.vars.insert(var)
|
||||
}
|
||||
|
||||
pub fn get_var_mut(&mut self, id: VarId) -> Option<&mut Var> {
|
||||
self.vars.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn points(&self) -> &EntityMap<PointEntity> {
|
||||
&self.points
|
||||
}
|
||||
|
||||
pub fn insert_point(&mut self, point: PointEntity) -> PointId {
|
||||
assert!(self.vars.contains(point.x));
|
||||
assert!(self.vars.contains(point.y));
|
||||
self.points.insert(point)
|
||||
}
|
||||
|
||||
pub fn insert_point_at(&mut self, point: impl Into<PointPos>) -> PointId {
|
||||
let point = point.into();
|
||||
let x = self.insert_var(Var::new_free(point.x));
|
||||
let y = self.insert_var(Var::new_free(point.y));
|
||||
self.insert_point(PointEntity::new(x, y))
|
||||
}
|
||||
|
||||
fn remove_point(&mut self, point_id: PointId) -> bool {
|
||||
// TODO: check if used by any line
|
||||
self.points.remove(point_id)
|
||||
}
|
||||
|
||||
pub fn point_pos(&self, point: &PointEntity) -> Var<PointPos> {
|
||||
// TODO: error handling?
|
||||
let x = self.vars.get(point.x).unwrap();
|
||||
let y = self.vars.get(point.y).unwrap();
|
||||
x.merge(*y, PointPos::new)
|
||||
}
|
||||
|
||||
pub fn lines(&self) -> &EntityMap<LineEntity> {
|
||||
&self.lines
|
||||
}
|
||||
|
||||
pub fn insert_line(&mut self, line: LineEntity) -> LineId {
|
||||
assert!(self.points.contains(line.start));
|
||||
assert!(self.points.contains(line.end));
|
||||
self.lines.insert(line)
|
||||
}
|
||||
|
||||
fn remove_line(&mut self, line_id: LineId) -> bool {
|
||||
self.lines.remove(line_id)
|
||||
}
|
||||
|
||||
pub fn line_pos(&self, line: &LineEntity) -> Var<LinePos> {
|
||||
// TODO: error handling?
|
||||
let start = self.points.get(line.start).unwrap();
|
||||
let start = self.point_pos(start);
|
||||
|
||||
let end = self.points.get(line.end).unwrap();
|
||||
let end = self.point_pos(end);
|
||||
start.merge(end, LinePos::new)
|
||||
}
|
||||
}
|
113
src/geometry/entity.rs
Normal file
113
src/geometry/entity.rs
Normal file
@ -0,0 +1,113 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
|
||||
// HashMap with fast no-op hash function and preserves insertion order
|
||||
type IntIndexMap<K, V> = IndexMap<K, V, nohash_hasher::BuildNoHashHasher<K>>;
|
||||
|
||||
pub struct EntityId<T>(u32, PhantomData<T>);
|
||||
|
||||
impl<T> Default for EntityId<T> {
|
||||
fn default() -> Self {
|
||||
Self(1, Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> nohash_hasher::IsEnabled for EntityId<T> {}
|
||||
|
||||
impl<T> EntityId<T> {
|
||||
fn take_next(&mut self) -> Self {
|
||||
let id = *self;
|
||||
self.0 += 1;
|
||||
id
|
||||
}
|
||||
|
||||
fn invalid() -> Self {
|
||||
Self(u32::MAX, Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EntityMap<T> {
|
||||
map: IntIndexMap<EntityId<T>, T>,
|
||||
next_id: EntityId<T>,
|
||||
}
|
||||
|
||||
impl<T> EntityMap<T> {
|
||||
pub fn insert(&mut self, entity: T) -> EntityId<T> {
|
||||
let id = self.next_id.take_next();
|
||||
self.map.insert(id, entity);
|
||||
id
|
||||
}
|
||||
|
||||
pub fn contains(&self, id: EntityId<T>) -> bool {
|
||||
self.map.contains_key(&id)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: EntityId<T>) -> bool {
|
||||
self.map.remove(&id).is_some()
|
||||
}
|
||||
|
||||
pub fn get(&self, id: EntityId<T>) -> Option<&T> {
|
||||
self.map.get(&id)
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, id: EntityId<T>) -> Option<&mut T> {
|
||||
self.map.get_mut(&id)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (EntityId<T>, &T)> {
|
||||
self.map.iter().map(|(id, v)| (*id, v))
|
||||
}
|
||||
}
|
||||
|
||||
// Need to reimplement all of this without the T bound :/
|
||||
|
||||
impl<T> PartialEq for EntityId<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0 && self.1 == other.1
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Eq for EntityId<T> {}
|
||||
|
||||
impl<T> PartialOrd for EntityId<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
self.0.partial_cmp(&other.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Ord for EntityId<T> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.0.cmp(&other.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for EntityId<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone(), self.1.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Copy for EntityId<T> {}
|
||||
|
||||
impl<T> std::hash::Hash for EntityId<T> {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.0.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::fmt::Debug for EntityId<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_tuple("EntityId").field(&self.0).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for EntityMap<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
map: Default::default(),
|
||||
next_id: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
62
src/geometry/var.rs
Normal file
62
src/geometry/var.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use super::{EntityId, Scalar};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub enum VarStatus {
|
||||
Free,
|
||||
Dependent,
|
||||
Unique,
|
||||
Overconstrained,
|
||||
}
|
||||
|
||||
impl VarStatus {
|
||||
pub fn merge(self, other: Self) -> Self {
|
||||
use VarStatus::*;
|
||||
match (self, other) {
|
||||
(Free, Free) => Free,
|
||||
(Unique, Unique) => Unique,
|
||||
(Overconstrained, _) | (_, Overconstrained) => Overconstrained,
|
||||
_ => Dependent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Var<T = Scalar> {
|
||||
pub value: T,
|
||||
pub status: VarStatus,
|
||||
}
|
||||
|
||||
impl<T> Var<T> {
|
||||
pub fn new_free(value: T) -> Self {
|
||||
Self {
|
||||
value,
|
||||
status: VarStatus::Free,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn merge<U, V, F>(self, other: Var<U>, by: F) -> Var<V>
|
||||
where
|
||||
F: FnOnce(T, U) -> V,
|
||||
{
|
||||
Var {
|
||||
value: by(self.value, other.value),
|
||||
status: self.status.merge(other.status),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for Var<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for Var<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.value
|
||||
}
|
||||
}
|
||||
|
||||
pub type VarId = EntityId<Var>;
|
303
src/main.rs
Normal file
303
src/main.rs
Normal file
@ -0,0 +1,303 @@
|
||||
use eframe::{
|
||||
egui::{self, CursorIcon, Frame, PointerButton, Sense, Ui},
|
||||
emath::{Pos2, Rect, RectTransform, Vec2},
|
||||
epaint::{color::Hsva, Color32, Stroke},
|
||||
};
|
||||
use geometry::{LineEntity, PointId, PointPos, SketchEntities};
|
||||
|
||||
mod geometry;
|
||||
mod optimization;
|
||||
mod relations;
|
||||
|
||||
fn main() {
|
||||
let options = eframe::NativeOptions::default();
|
||||
eframe::run_native(
|
||||
"sketchrs",
|
||||
options,
|
||||
Box::new(|_cc| Box::new(MyApp::default())),
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Tool {
|
||||
Select,
|
||||
Move,
|
||||
AddPoint,
|
||||
AddLine,
|
||||
}
|
||||
|
||||
impl Default for Tool {
|
||||
fn default() -> Self {
|
||||
Tool::Select
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ViewState {
|
||||
tool: Tool,
|
||||
hover_point: Option<PointId>,
|
||||
select_point: Option<PointId>,
|
||||
drag_delta: Vec2,
|
||||
}
|
||||
|
||||
struct MyApp {
|
||||
state: ViewState,
|
||||
entities: SketchEntities,
|
||||
}
|
||||
|
||||
impl Default for MyApp {
|
||||
fn default() -> Self {
|
||||
let mut entities = SketchEntities::default();
|
||||
let p1 = entities.insert_point_at((10., 30.));
|
||||
let p2 = entities.insert_point_at((-20., 15.));
|
||||
entities.insert_point_at((0., -10.));
|
||||
entities.insert_line(LineEntity::new(p1, p2));
|
||||
Self {
|
||||
state: ViewState::default(),
|
||||
entities,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn color_for_var_status(status: geometry::VarStatus) -> Hsva {
|
||||
use geometry::VarStatus::*;
|
||||
match status {
|
||||
Free => Hsva::new(200. / 360., 0.90, 0.80, 1.0),
|
||||
Dependent => todo!(),
|
||||
Unique => todo!(),
|
||||
Overconstrained => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
const POINT_RADIUS: f32 = 3.0;
|
||||
|
||||
impl MyApp {
|
||||
fn show_toolbar(&mut self, ui: &mut Ui) {
|
||||
ui.heading("sketchrs");
|
||||
|
||||
ui.horizontal(|ui| {
|
||||
ui.label("Tool: ");
|
||||
let mut tool = &mut self.state.tool;
|
||||
ui.selectable_value(tool, Tool::Select, "Select");
|
||||
ui.selectable_value(tool, Tool::Move, "Move");
|
||||
ui.selectable_value(tool, Tool::AddPoint, "+ Point");
|
||||
ui.selectable_value(tool, Tool::AddLine, "+ Line");
|
||||
});
|
||||
}
|
||||
|
||||
fn show_entities(&mut self, ui: &mut Ui) {
|
||||
let mut state = &mut self.state;
|
||||
|
||||
let sense = match state.tool {
|
||||
Tool::Move => Sense::drag(),
|
||||
Tool::Select | Tool::AddPoint | Tool::AddLine => Sense::click(),
|
||||
};
|
||||
|
||||
let (response, painter) = ui.allocate_painter(ui.available_size(), sense);
|
||||
let ctx = response.ctx.clone();
|
||||
let to_screen = RectTransform::from_to(
|
||||
Rect::from_center_size(Pos2::ZERO, response.rect.size() / 2.0),
|
||||
response.rect,
|
||||
);
|
||||
|
||||
let transform_pos = |pos: &PointPos| to_screen * Pos2::new(pos.x as f32, pos.y as f32);
|
||||
|
||||
let mut hover_line = None;
|
||||
|
||||
state.hover_point = None;
|
||||
|
||||
if let Some(hover_pos) = response.hover_pos() {
|
||||
for (id, point) in self
|
||||
.entities
|
||||
.points()
|
||||
.iter()
|
||||
.filter(|(id, _)| Some(*id) != state.select_point)
|
||||
{
|
||||
let pos = self.entities.point_pos(point);
|
||||
let center = transform_pos(&*pos);
|
||||
|
||||
if (hover_pos - center).length() < (POINT_RADIUS * 3.) {
|
||||
state.hover_point = Some(id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (id, line) in self.entities.lines().iter() {
|
||||
let pos = self.entities.line_pos(line);
|
||||
let points = [transform_pos(&pos.start), transform_pos(&pos.end)];
|
||||
|
||||
let b = points[1] - points[0];
|
||||
let a = hover_pos - points[0];
|
||||
let p = a.dot(b) / b.dot(b);
|
||||
let perp = a - (p * b);
|
||||
let hovered = ((0.)..=(1.)).contains(&p) && perp.length() < 5.0;
|
||||
|
||||
if hovered {
|
||||
hover_line = Some(id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ctx.output().cursor_icon = match state.tool {
|
||||
Tool::Select => {
|
||||
if state.hover_point.is_some() {
|
||||
CursorIcon::PointingHand
|
||||
} else {
|
||||
CursorIcon::Default
|
||||
}
|
||||
}
|
||||
Tool::Move => {
|
||||
if state.select_point.is_some() {
|
||||
CursorIcon::Grabbing
|
||||
} else if state.hover_point.is_some() {
|
||||
CursorIcon::Grab
|
||||
} else {
|
||||
CursorIcon::Move
|
||||
}
|
||||
}
|
||||
Tool::AddPoint => CursorIcon::None,
|
||||
Tool::AddLine => CursorIcon::None,
|
||||
};
|
||||
|
||||
let mut add_point = || {
|
||||
let point_pos = to_screen.inverse() * hover_pos;
|
||||
let point_pos = (point_pos.x as f64, point_pos.y as f64);
|
||||
self.entities.insert_point_at(point_pos)
|
||||
};
|
||||
|
||||
match state.tool {
|
||||
Tool::Select => {
|
||||
if response.clicked() {
|
||||
state.select_point = state.hover_point;
|
||||
}
|
||||
}
|
||||
Tool::Move => {
|
||||
let drag_started = response.drag_started();
|
||||
if drag_started {
|
||||
state.select_point = state.hover_point;
|
||||
} else if response.drag_released() {
|
||||
state.select_point = None;
|
||||
state.drag_delta = Vec2::ZERO;
|
||||
}
|
||||
|
||||
if let Some(point_id) = state.select_point {
|
||||
let point = self.entities.points().get(point_id).unwrap();
|
||||
if drag_started {
|
||||
let point_pos = self.entities.point_pos(point);
|
||||
state.drag_delta = hover_pos - transform_pos(&*point_pos);
|
||||
}
|
||||
let move_to = to_screen.inverse() * (hover_pos - state.drag_delta);
|
||||
for (id, val) in [(point.x, move_to.x), (point.y, move_to.y)] {
|
||||
self.entities.get_var_mut(id).unwrap().value = val as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
Tool::AddPoint => {
|
||||
if response.clicked() {
|
||||
add_point();
|
||||
} else {
|
||||
painter.circle_filled(hover_pos, POINT_RADIUS, Color32::WHITE);
|
||||
}
|
||||
}
|
||||
Tool::AddLine => {
|
||||
match (state.select_point, response.clicked()) {
|
||||
(Some(start_point_id), true) => {
|
||||
// TODO: add point if no hover point
|
||||
let end_point = state.hover_point.unwrap_or_else(add_point);
|
||||
|
||||
let line = LineEntity::new(start_point_id, end_point);
|
||||
self.entities.insert_line(line);
|
||||
|
||||
state.select_point = None;
|
||||
}
|
||||
(None, true) => {
|
||||
state.select_point = Some(state.hover_point.unwrap_or_else(add_point));
|
||||
}
|
||||
(Some(first_point_id), false) => {
|
||||
let first_point = self.entities.points().get(first_point_id).unwrap();
|
||||
let first_point_pos = self.entities.point_pos(first_point);
|
||||
let points = [transform_pos(&first_point_pos), hover_pos];
|
||||
|
||||
let stroke = Stroke::new(2.0, Color32::DARK_GRAY);
|
||||
|
||||
painter.line_segment(points, stroke);
|
||||
|
||||
painter.circle_filled(hover_pos, POINT_RADIUS, Color32::DARK_GRAY);
|
||||
}
|
||||
(None, false) => {
|
||||
painter.circle_filled(hover_pos, POINT_RADIUS, Color32::DARK_GRAY);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (id, line) in self.entities.lines().iter() {
|
||||
let pos = self.entities.line_pos(line);
|
||||
let points = [transform_pos(&pos.start), transform_pos(&pos.end)];
|
||||
|
||||
let mut color = color_for_var_status(pos.status);
|
||||
color.v -= 0.6;
|
||||
if state.hover_point.is_none()
|
||||
&& (state.select_point.is_none() || state.tool != Tool::Move)
|
||||
&& Some(id) == hover_line
|
||||
{
|
||||
color.s -= 0.8;
|
||||
}
|
||||
|
||||
let stroke = Stroke::new(2.0, color);
|
||||
|
||||
painter.line_segment(points, stroke);
|
||||
}
|
||||
|
||||
for (id, point) in self.entities.points().iter() {
|
||||
let pos = self.entities.point_pos(point);
|
||||
|
||||
let center = transform_pos(&*pos);
|
||||
|
||||
let mut color = color_for_var_status(pos.status);
|
||||
let stroke = match (state.select_point, state.hover_point) {
|
||||
(Some(sid), _) | (_, Some(sid)) if id == sid => {
|
||||
// color.s -= 0.8;
|
||||
Stroke::new(1.0, Color32::WHITE)
|
||||
}
|
||||
_ => Stroke::default(),
|
||||
};
|
||||
painter.circle(center, POINT_RADIUS, color, stroke);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl eframe::App for MyApp {
|
||||
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
|
||||
ctx.request_repaint();
|
||||
ctx.set_visuals(egui::Visuals::dark());
|
||||
|
||||
egui::TopBottomPanel::top("top_panel").show(ctx, |ui| self.show_toolbar(ui));
|
||||
|
||||
egui::CentralPanel::default()
|
||||
.frame(egui::Frame::none())
|
||||
.show(ctx, |ui| {
|
||||
self.show_entities(ui);
|
||||
});
|
||||
|
||||
egui::SidePanel::right("right_panel")
|
||||
.resizable(true)
|
||||
.default_width(150.0)
|
||||
.width_range(80.0..=200.0)
|
||||
.show(ctx, |ui| {
|
||||
ui.vertical_centered(|ui| {
|
||||
ui.heading("Left Panel");
|
||||
});
|
||||
// egui::ScrollArea::vertical().show(ui, |ui| {
|
||||
// lorem_ipsum(ui);
|
||||
// });
|
||||
});
|
||||
|
||||
egui::TopBottomPanel::bottom("bottom_panel").show(ctx, |ui| {
|
||||
ui.horizontal(|ui| {
|
||||
ui.label("Status: ");
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
122
src/optimization.rs
Normal file
122
src/optimization.rs
Normal file
@ -0,0 +1,122 @@
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
use nalgebra::{DMatrix, DVector, DimName};
|
||||
use nalgebra_sparse::CooMatrix;
|
||||
|
||||
use crate::geometry::Scalar;
|
||||
|
||||
type SVector<const D: usize> = nalgebra::SVector<Scalar, D>;
|
||||
|
||||
trait SResidual<const N: usize> {
|
||||
fn apply(&self, x: SVector<N>) -> Scalar;
|
||||
fn grad(&self, x: SVector<N>) -> SVector<N>;
|
||||
}
|
||||
|
||||
struct PointPointDistance(Scalar);
|
||||
|
||||
fn square(x: Scalar) -> Scalar {
|
||||
x * x
|
||||
}
|
||||
|
||||
impl SResidual<4> for PointPointDistance {
|
||||
fn apply(&self, x: SVector<4>) -> Scalar {
|
||||
let [[x1, y1, x2, y2]] = x.data.0;
|
||||
// x.generic_slice()
|
||||
square(x2 - x1) + square(y2 - y1) - square(self.0)
|
||||
// (square(x2 - x1) + square(y2 - y1)).sqrt() - self.0
|
||||
}
|
||||
|
||||
fn grad(&self, x: SVector<4>) -> SVector<4> {
|
||||
let [[x1, y1, x2, y2]] = x.data.0;
|
||||
SVector::<4>::new(
|
||||
2. * (x1 - x2),
|
||||
2. * (y1 - y2),
|
||||
2. * (x2 - x1),
|
||||
2. * (y2 - x1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
trait IndexedResidual {
|
||||
fn apply(&self, x: &[Scalar]) -> Scalar;
|
||||
fn grad_indices(&self) -> &[usize];
|
||||
fn grad_append(&self, x: &[Scalar], grad: &mut Vec<Scalar>);
|
||||
}
|
||||
|
||||
struct MappedResidual<const N: usize, R> {
|
||||
residual: R,
|
||||
indices: [usize; N],
|
||||
}
|
||||
|
||||
impl<const N: usize, R> MappedResidual<N, R> {
|
||||
fn gather_svector(&self, x: &[Scalar]) -> SVector<N> {
|
||||
SVector::from_fn(|_, i| x[self.indices[i]])
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, R: SResidual<N>> IndexedResidual for MappedResidual<N, R> {
|
||||
fn apply(&self, x: &[Scalar]) -> Scalar {
|
||||
self.residual.apply(self.gather_svector(x))
|
||||
}
|
||||
|
||||
fn grad_indices(&self) -> &[usize] {
|
||||
&self.indices
|
||||
}
|
||||
|
||||
fn grad_append(&self, x: &[Scalar], mut grad: &mut Vec<Scalar>) {
|
||||
let gvals = self.residual.grad(self.gather_svector(x));
|
||||
grad.extend(gvals.as_slice());
|
||||
}
|
||||
}
|
||||
|
||||
struct Residuals {
|
||||
residuals: Vec<Box<dyn IndexedResidual>>,
|
||||
}
|
||||
|
||||
impl Residuals {
|
||||
fn apply_fill(&self, x: &[Scalar], residuals: &mut [Scalar]) {
|
||||
assert_eq!(self.residuals.len(), residuals.len());
|
||||
for (resbox, rval) in self.residuals.iter().zip(residuals) {
|
||||
*rval = resbox.apply(x);
|
||||
}
|
||||
}
|
||||
|
||||
// fn grad_fill(&self, x: &[Scalar], grad: &mut [Scalar]) {
|
||||
// assert_eq!(self.residuals.len(), grad.len());
|
||||
// let mut gradcol = Vec::new();
|
||||
// for resbox in self.residuals.iter() {
|
||||
// let res = resbox.apply(x);
|
||||
// let indices = resbox.grad_indices();
|
||||
// gradcol.resize(indices.len(), 0.);
|
||||
// resbox.grad_fill(x, grad);
|
||||
// for i in indices {
|
||||
// grad[i] = 2 * res * gradcol
|
||||
// }
|
||||
// // resbox.grad_fill(x, gradcol.as_mut_slice());
|
||||
// }
|
||||
// }
|
||||
|
||||
fn jacobian(&self, x: &[Scalar]) -> CooMatrix<Scalar> {
|
||||
let gradlen = self
|
||||
.residuals
|
||||
.iter()
|
||||
.map(|res| res.grad_indices().len())
|
||||
.sum::<usize>();
|
||||
let mut row_indices = Vec::with_capacity(gradlen);
|
||||
let mut col_indices = Vec::with_capacity(gradlen);
|
||||
let mut grads = Vec::with_capacity(gradlen);
|
||||
for (j, resbox) in self.residuals.iter().enumerate() {
|
||||
row_indices.extend(resbox.grad_indices());
|
||||
col_indices.extend(std::iter::repeat(j).take(resbox.grad_indices().len()));
|
||||
resbox.grad_append(x, &mut grads);
|
||||
}
|
||||
CooMatrix::try_from_triplets(
|
||||
x.len(),
|
||||
self.residuals.len(),
|
||||
row_indices,
|
||||
col_indices,
|
||||
grads,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
25
src/relations.rs
Normal file
25
src/relations.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use crate::geometry::{PointId, Scalar, PointPos};
|
||||
|
||||
struct SketchRelations {
|
||||
relations: Vec<SketchRelation>,
|
||||
}
|
||||
|
||||
pub enum SketchRelation {
|
||||
Point {
|
||||
p: PointId,
|
||||
rel: PointRelation,
|
||||
},
|
||||
PointPoint {
|
||||
a: PointId,
|
||||
b: PointId,
|
||||
rel: PointPointRelation,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum PointRelation {
|
||||
Fix(PointPos),
|
||||
}
|
||||
|
||||
pub enum PointPointRelation {
|
||||
Distance(Scalar),
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user