Initial commit

This commit is contained in:
Alex Mikhalev 2022-05-22 00:13:30 -07:00
commit 73bc64a625
9 changed files with 3209 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

2433
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

19
Cargo.toml Normal file
View File

@ -0,0 +1,19 @@
[package]
name = "sketchrs"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
argmin = { version = "0.5.1", features = ["nalgebra"] }
eframe = { version = "0.18.0", features = [] }
indexmap = "1.8.1"
levenberg-marquardt = "0.12.0"
nalgebra = "0.30.1"
nalgebra-sparse = "0.6.0"
nohash-hasher = "0.2.0"
[dev-dependencies]
criterion = "0.3.5"

131
src/geometry.rs Normal file
View File

@ -0,0 +1,131 @@
mod entity;
mod var;
use self::entity::{EntityId, EntityMap};
pub use self::var::*;
pub type Scalar = f64;
#[derive(Clone, Debug)]
pub struct PointEntity<V = VarId> {
pub x: V,
pub y: V,
}
impl<V> PointEntity<V> {
pub fn new(x: V, y: V) -> Self {
Self { x, y }
}
}
impl<V> From<(V, V)> for PointEntity<V> {
fn from(p: (V, V)) -> Self {
Self { x: p.0, y: p.1 }
}
}
pub type PointPos = PointEntity<Scalar>;
pub type PointId = EntityId<PointEntity>;
#[derive(Clone, Debug)]
// TODO: segment, ray and full line
// TODO: what if start and end are coincident?
pub struct LineEntity<P = PointId> {
pub start: P,
pub end: P,
}
impl<P> LineEntity<P> {
pub fn new(start: P, end: P) -> Self {
Self { start, end }
}
}
pub type LinePos = LineEntity<PointPos>;
pub type LineId = EntityId<LineEntity>;
#[derive(Clone)]
// TODO: arc. how to represent start and end of arc?
// angles, unit vectors, or a point for each?
pub struct CircleEntity {
pub center: PointId,
pub radius: VarId,
}
// pub type CirclePos = CircleEntity<PointPos>;
pub type CircleId = EntityId<CircleEntity>;
#[derive(Clone, Default)]
pub struct SketchEntities {
vars: EntityMap<Var>,
points: EntityMap<PointEntity>,
lines: EntityMap<LineEntity>,
circles: EntityMap<CircleEntity>,
}
impl SketchEntities {
pub fn vars(&self) -> &EntityMap<Var> {
&self.vars
}
pub fn insert_var(&mut self, var: Var) -> VarId {
self.vars.insert(var)
}
pub fn get_var_mut(&mut self, id: VarId) -> Option<&mut Var> {
self.vars.get_mut(id)
}
pub fn points(&self) -> &EntityMap<PointEntity> {
&self.points
}
pub fn insert_point(&mut self, point: PointEntity) -> PointId {
assert!(self.vars.contains(point.x));
assert!(self.vars.contains(point.y));
self.points.insert(point)
}
pub fn insert_point_at(&mut self, point: impl Into<PointPos>) -> PointId {
let point = point.into();
let x = self.insert_var(Var::new_free(point.x));
let y = self.insert_var(Var::new_free(point.y));
self.insert_point(PointEntity::new(x, y))
}
fn remove_point(&mut self, point_id: PointId) -> bool {
// TODO: check if used by any line
self.points.remove(point_id)
}
pub fn point_pos(&self, point: &PointEntity) -> Var<PointPos> {
// TODO: error handling?
let x = self.vars.get(point.x).unwrap();
let y = self.vars.get(point.y).unwrap();
x.merge(*y, PointPos::new)
}
pub fn lines(&self) -> &EntityMap<LineEntity> {
&self.lines
}
pub fn insert_line(&mut self, line: LineEntity) -> LineId {
assert!(self.points.contains(line.start));
assert!(self.points.contains(line.end));
self.lines.insert(line)
}
fn remove_line(&mut self, line_id: LineId) -> bool {
self.lines.remove(line_id)
}
pub fn line_pos(&self, line: &LineEntity) -> Var<LinePos> {
// TODO: error handling?
let start = self.points.get(line.start).unwrap();
let start = self.point_pos(start);
let end = self.points.get(line.end).unwrap();
let end = self.point_pos(end);
start.merge(end, LinePos::new)
}
}

113
src/geometry/entity.rs Normal file
View File

@ -0,0 +1,113 @@
use std::marker::PhantomData;
use indexmap::IndexMap;
// HashMap with fast no-op hash function and preserves insertion order
type IntIndexMap<K, V> = IndexMap<K, V, nohash_hasher::BuildNoHashHasher<K>>;
pub struct EntityId<T>(u32, PhantomData<T>);
impl<T> Default for EntityId<T> {
fn default() -> Self {
Self(1, Default::default())
}
}
impl<T> nohash_hasher::IsEnabled for EntityId<T> {}
impl<T> EntityId<T> {
fn take_next(&mut self) -> Self {
let id = *self;
self.0 += 1;
id
}
fn invalid() -> Self {
Self(u32::MAX, Default::default())
}
}
#[derive(Clone)]
pub struct EntityMap<T> {
map: IntIndexMap<EntityId<T>, T>,
next_id: EntityId<T>,
}
impl<T> EntityMap<T> {
pub fn insert(&mut self, entity: T) -> EntityId<T> {
let id = self.next_id.take_next();
self.map.insert(id, entity);
id
}
pub fn contains(&self, id: EntityId<T>) -> bool {
self.map.contains_key(&id)
}
pub fn remove(&mut self, id: EntityId<T>) -> bool {
self.map.remove(&id).is_some()
}
pub fn get(&self, id: EntityId<T>) -> Option<&T> {
self.map.get(&id)
}
pub fn get_mut(&mut self, id: EntityId<T>) -> Option<&mut T> {
self.map.get_mut(&id)
}
pub fn iter(&self) -> impl Iterator<Item = (EntityId<T>, &T)> {
self.map.iter().map(|(id, v)| (*id, v))
}
}
// Need to reimplement all of this without the T bound :/
impl<T> PartialEq for EntityId<T> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0 && self.1 == other.1
}
}
impl<T> Eq for EntityId<T> {}
impl<T> PartialOrd for EntityId<T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(&other.0)
}
}
impl<T> Ord for EntityId<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.cmp(&other.0)
}
}
impl<T> Clone for EntityId<T> {
fn clone(&self) -> Self {
Self(self.0.clone(), self.1.clone())
}
}
impl<T> Copy for EntityId<T> {}
impl<T> std::hash::Hash for EntityId<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl<T> std::fmt::Debug for EntityId<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("EntityId").field(&self.0).finish()
}
}
impl<T> Default for EntityMap<T> {
fn default() -> Self {
Self {
map: Default::default(),
next_id: Default::default(),
}
}
}

62
src/geometry/var.rs Normal file
View File

@ -0,0 +1,62 @@
use super::{EntityId, Scalar};
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum VarStatus {
Free,
Dependent,
Unique,
Overconstrained,
}
impl VarStatus {
pub fn merge(self, other: Self) -> Self {
use VarStatus::*;
match (self, other) {
(Free, Free) => Free,
(Unique, Unique) => Unique,
(Overconstrained, _) | (_, Overconstrained) => Overconstrained,
_ => Dependent,
}
}
}
#[derive(Clone, Copy, Debug)]
pub struct Var<T = Scalar> {
pub value: T,
pub status: VarStatus,
}
impl<T> Var<T> {
pub fn new_free(value: T) -> Self {
Self {
value,
status: VarStatus::Free,
}
}
pub(super) fn merge<U, V, F>(self, other: Var<U>, by: F) -> Var<V>
where
F: FnOnce(T, U) -> V,
{
Var {
value: by(self.value, other.value),
status: self.status.merge(other.status),
}
}
}
impl<T> std::ops::Deref for Var<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl<T> std::ops::DerefMut for Var<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
}
pub type VarId = EntityId<Var>;

303
src/main.rs Normal file
View File

@ -0,0 +1,303 @@
use eframe::{
egui::{self, CursorIcon, Frame, PointerButton, Sense, Ui},
emath::{Pos2, Rect, RectTransform, Vec2},
epaint::{color::Hsva, Color32, Stroke},
};
use geometry::{LineEntity, PointId, PointPos, SketchEntities};
mod geometry;
mod optimization;
mod relations;
fn main() {
let options = eframe::NativeOptions::default();
eframe::run_native(
"sketchrs",
options,
Box::new(|_cc| Box::new(MyApp::default())),
);
}
#[derive(PartialEq)]
enum Tool {
Select,
Move,
AddPoint,
AddLine,
}
impl Default for Tool {
fn default() -> Self {
Tool::Select
}
}
#[derive(Default)]
struct ViewState {
tool: Tool,
hover_point: Option<PointId>,
select_point: Option<PointId>,
drag_delta: Vec2,
}
struct MyApp {
state: ViewState,
entities: SketchEntities,
}
impl Default for MyApp {
fn default() -> Self {
let mut entities = SketchEntities::default();
let p1 = entities.insert_point_at((10., 30.));
let p2 = entities.insert_point_at((-20., 15.));
entities.insert_point_at((0., -10.));
entities.insert_line(LineEntity::new(p1, p2));
Self {
state: ViewState::default(),
entities,
}
}
}
fn color_for_var_status(status: geometry::VarStatus) -> Hsva {
use geometry::VarStatus::*;
match status {
Free => Hsva::new(200. / 360., 0.90, 0.80, 1.0),
Dependent => todo!(),
Unique => todo!(),
Overconstrained => todo!(),
}
}
const POINT_RADIUS: f32 = 3.0;
impl MyApp {
fn show_toolbar(&mut self, ui: &mut Ui) {
ui.heading("sketchrs");
ui.horizontal(|ui| {
ui.label("Tool: ");
let mut tool = &mut self.state.tool;
ui.selectable_value(tool, Tool::Select, "Select");
ui.selectable_value(tool, Tool::Move, "Move");
ui.selectable_value(tool, Tool::AddPoint, "+ Point");
ui.selectable_value(tool, Tool::AddLine, "+ Line");
});
}
fn show_entities(&mut self, ui: &mut Ui) {
let mut state = &mut self.state;
let sense = match state.tool {
Tool::Move => Sense::drag(),
Tool::Select | Tool::AddPoint | Tool::AddLine => Sense::click(),
};
let (response, painter) = ui.allocate_painter(ui.available_size(), sense);
let ctx = response.ctx.clone();
let to_screen = RectTransform::from_to(
Rect::from_center_size(Pos2::ZERO, response.rect.size() / 2.0),
response.rect,
);
let transform_pos = |pos: &PointPos| to_screen * Pos2::new(pos.x as f32, pos.y as f32);
let mut hover_line = None;
state.hover_point = None;
if let Some(hover_pos) = response.hover_pos() {
for (id, point) in self
.entities
.points()
.iter()
.filter(|(id, _)| Some(*id) != state.select_point)
{
let pos = self.entities.point_pos(point);
let center = transform_pos(&*pos);
if (hover_pos - center).length() < (POINT_RADIUS * 3.) {
state.hover_point = Some(id);
break;
}
}
for (id, line) in self.entities.lines().iter() {
let pos = self.entities.line_pos(line);
let points = [transform_pos(&pos.start), transform_pos(&pos.end)];
let b = points[1] - points[0];
let a = hover_pos - points[0];
let p = a.dot(b) / b.dot(b);
let perp = a - (p * b);
let hovered = ((0.)..=(1.)).contains(&p) && perp.length() < 5.0;
if hovered {
hover_line = Some(id);
break;
}
}
ctx.output().cursor_icon = match state.tool {
Tool::Select => {
if state.hover_point.is_some() {
CursorIcon::PointingHand
} else {
CursorIcon::Default
}
}
Tool::Move => {
if state.select_point.is_some() {
CursorIcon::Grabbing
} else if state.hover_point.is_some() {
CursorIcon::Grab
} else {
CursorIcon::Move
}
}
Tool::AddPoint => CursorIcon::None,
Tool::AddLine => CursorIcon::None,
};
let mut add_point = || {
let point_pos = to_screen.inverse() * hover_pos;
let point_pos = (point_pos.x as f64, point_pos.y as f64);
self.entities.insert_point_at(point_pos)
};
match state.tool {
Tool::Select => {
if response.clicked() {
state.select_point = state.hover_point;
}
}
Tool::Move => {
let drag_started = response.drag_started();
if drag_started {
state.select_point = state.hover_point;
} else if response.drag_released() {
state.select_point = None;
state.drag_delta = Vec2::ZERO;
}
if let Some(point_id) = state.select_point {
let point = self.entities.points().get(point_id).unwrap();
if drag_started {
let point_pos = self.entities.point_pos(point);
state.drag_delta = hover_pos - transform_pos(&*point_pos);
}
let move_to = to_screen.inverse() * (hover_pos - state.drag_delta);
for (id, val) in [(point.x, move_to.x), (point.y, move_to.y)] {
self.entities.get_var_mut(id).unwrap().value = val as f64;
}
}
}
Tool::AddPoint => {
if response.clicked() {
add_point();
} else {
painter.circle_filled(hover_pos, POINT_RADIUS, Color32::WHITE);
}
}
Tool::AddLine => {
match (state.select_point, response.clicked()) {
(Some(start_point_id), true) => {
// TODO: add point if no hover point
let end_point = state.hover_point.unwrap_or_else(add_point);
let line = LineEntity::new(start_point_id, end_point);
self.entities.insert_line(line);
state.select_point = None;
}
(None, true) => {
state.select_point = Some(state.hover_point.unwrap_or_else(add_point));
}
(Some(first_point_id), false) => {
let first_point = self.entities.points().get(first_point_id).unwrap();
let first_point_pos = self.entities.point_pos(first_point);
let points = [transform_pos(&first_point_pos), hover_pos];
let stroke = Stroke::new(2.0, Color32::DARK_GRAY);
painter.line_segment(points, stroke);
painter.circle_filled(hover_pos, POINT_RADIUS, Color32::DARK_GRAY);
}
(None, false) => {
painter.circle_filled(hover_pos, POINT_RADIUS, Color32::DARK_GRAY);
}
}
}
}
}
for (id, line) in self.entities.lines().iter() {
let pos = self.entities.line_pos(line);
let points = [transform_pos(&pos.start), transform_pos(&pos.end)];
let mut color = color_for_var_status(pos.status);
color.v -= 0.6;
if state.hover_point.is_none()
&& (state.select_point.is_none() || state.tool != Tool::Move)
&& Some(id) == hover_line
{
color.s -= 0.8;
}
let stroke = Stroke::new(2.0, color);
painter.line_segment(points, stroke);
}
for (id, point) in self.entities.points().iter() {
let pos = self.entities.point_pos(point);
let center = transform_pos(&*pos);
let mut color = color_for_var_status(pos.status);
let stroke = match (state.select_point, state.hover_point) {
(Some(sid), _) | (_, Some(sid)) if id == sid => {
// color.s -= 0.8;
Stroke::new(1.0, Color32::WHITE)
}
_ => Stroke::default(),
};
painter.circle(center, POINT_RADIUS, color, stroke);
}
}
}
impl eframe::App for MyApp {
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
ctx.request_repaint();
ctx.set_visuals(egui::Visuals::dark());
egui::TopBottomPanel::top("top_panel").show(ctx, |ui| self.show_toolbar(ui));
egui::CentralPanel::default()
.frame(egui::Frame::none())
.show(ctx, |ui| {
self.show_entities(ui);
});
egui::SidePanel::right("right_panel")
.resizable(true)
.default_width(150.0)
.width_range(80.0..=200.0)
.show(ctx, |ui| {
ui.vertical_centered(|ui| {
ui.heading("Left Panel");
});
// egui::ScrollArea::vertical().show(ui, |ui| {
// lorem_ipsum(ui);
// });
});
egui::TopBottomPanel::bottom("bottom_panel").show(ctx, |ui| {
ui.horizontal(|ui| {
ui.label("Status: ");
});
});
}
}

122
src/optimization.rs Normal file
View File

@ -0,0 +1,122 @@
use std::mem::MaybeUninit;
use nalgebra::{DMatrix, DVector, DimName};
use nalgebra_sparse::CooMatrix;
use crate::geometry::Scalar;
type SVector<const D: usize> = nalgebra::SVector<Scalar, D>;
trait SResidual<const N: usize> {
fn apply(&self, x: SVector<N>) -> Scalar;
fn grad(&self, x: SVector<N>) -> SVector<N>;
}
struct PointPointDistance(Scalar);
fn square(x: Scalar) -> Scalar {
x * x
}
impl SResidual<4> for PointPointDistance {
fn apply(&self, x: SVector<4>) -> Scalar {
let [[x1, y1, x2, y2]] = x.data.0;
// x.generic_slice()
square(x2 - x1) + square(y2 - y1) - square(self.0)
// (square(x2 - x1) + square(y2 - y1)).sqrt() - self.0
}
fn grad(&self, x: SVector<4>) -> SVector<4> {
let [[x1, y1, x2, y2]] = x.data.0;
SVector::<4>::new(
2. * (x1 - x2),
2. * (y1 - y2),
2. * (x2 - x1),
2. * (y2 - x1),
)
}
}
trait IndexedResidual {
fn apply(&self, x: &[Scalar]) -> Scalar;
fn grad_indices(&self) -> &[usize];
fn grad_append(&self, x: &[Scalar], grad: &mut Vec<Scalar>);
}
struct MappedResidual<const N: usize, R> {
residual: R,
indices: [usize; N],
}
impl<const N: usize, R> MappedResidual<N, R> {
fn gather_svector(&self, x: &[Scalar]) -> SVector<N> {
SVector::from_fn(|_, i| x[self.indices[i]])
}
}
impl<const N: usize, R: SResidual<N>> IndexedResidual for MappedResidual<N, R> {
fn apply(&self, x: &[Scalar]) -> Scalar {
self.residual.apply(self.gather_svector(x))
}
fn grad_indices(&self) -> &[usize] {
&self.indices
}
fn grad_append(&self, x: &[Scalar], mut grad: &mut Vec<Scalar>) {
let gvals = self.residual.grad(self.gather_svector(x));
grad.extend(gvals.as_slice());
}
}
struct Residuals {
residuals: Vec<Box<dyn IndexedResidual>>,
}
impl Residuals {
fn apply_fill(&self, x: &[Scalar], residuals: &mut [Scalar]) {
assert_eq!(self.residuals.len(), residuals.len());
for (resbox, rval) in self.residuals.iter().zip(residuals) {
*rval = resbox.apply(x);
}
}
// fn grad_fill(&self, x: &[Scalar], grad: &mut [Scalar]) {
// assert_eq!(self.residuals.len(), grad.len());
// let mut gradcol = Vec::new();
// for resbox in self.residuals.iter() {
// let res = resbox.apply(x);
// let indices = resbox.grad_indices();
// gradcol.resize(indices.len(), 0.);
// resbox.grad_fill(x, grad);
// for i in indices {
// grad[i] = 2 * res * gradcol
// }
// // resbox.grad_fill(x, gradcol.as_mut_slice());
// }
// }
fn jacobian(&self, x: &[Scalar]) -> CooMatrix<Scalar> {
let gradlen = self
.residuals
.iter()
.map(|res| res.grad_indices().len())
.sum::<usize>();
let mut row_indices = Vec::with_capacity(gradlen);
let mut col_indices = Vec::with_capacity(gradlen);
let mut grads = Vec::with_capacity(gradlen);
for (j, resbox) in self.residuals.iter().enumerate() {
row_indices.extend(resbox.grad_indices());
col_indices.extend(std::iter::repeat(j).take(resbox.grad_indices().len()));
resbox.grad_append(x, &mut grads);
}
CooMatrix::try_from_triplets(
x.len(),
self.residuals.len(),
row_indices,
col_indices,
grads,
)
.unwrap()
}
}

25
src/relations.rs Normal file
View File

@ -0,0 +1,25 @@
use crate::geometry::{PointId, Scalar, PointPos};
struct SketchRelations {
relations: Vec<SketchRelation>,
}
pub enum SketchRelation {
Point {
p: PointId,
rel: PointRelation,
},
PointPoint {
a: PointId,
b: PointId,
rel: PointPointRelation,
},
}
pub enum PointRelation {
Fix(PointPos),
}
pub enum PointPointRelation {
Distance(Scalar),
}