Browse Source

improve math solver

master
Alex Mikhalev 6 years ago
parent
commit
a63f063a05
  1. 194
      Cargo.lock
  2. 5
      Cargo.toml
  3. 10
      src/main.rs
  4. 532
      src/math.rs

194
Cargo.lock generated

@ -1,3 +1,11 @@ @@ -1,3 +1,11 @@
[[package]]
name = "aho-corasick"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"memchr 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "alga"
version = "0.7.2"
@ -17,6 +25,16 @@ dependencies = [ @@ -17,6 +25,16 @@ dependencies = [
"num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "atty"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "bitflags"
version = "1.0.4"
@ -27,9 +45,17 @@ name = "cad_rs" @@ -27,9 +45,17 @@ name = "cad_rs"
version = "0.1.0"
dependencies = [
"approx 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"env_logger 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
"itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"nalgebra 0.16.13 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "cfg-if"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "cloudabi"
version = "0.0.3"
@ -38,6 +64,23 @@ dependencies = [ @@ -38,6 +64,23 @@ dependencies = [
"bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "either"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "env_logger"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)",
"humantime 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "fuchsia-cprng"
version = "0.1.0"
@ -51,6 +94,27 @@ dependencies = [ @@ -51,6 +94,27 @@ dependencies = [
"typenum 1.10.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "humantime"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "itertools"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "lazy_static"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "libc"
version = "0.2.48"
@ -61,6 +125,14 @@ name = "libm" @@ -61,6 +125,14 @@ name = "libm"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "log"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "matrixmultiply"
version = "0.1.15"
@ -69,6 +141,15 @@ dependencies = [ @@ -69,6 +141,15 @@ dependencies = [
"rawpointer 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "memchr"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "nalgebra"
version = "0.16.13"
@ -97,6 +178,11 @@ name = "num-traits" @@ -97,6 +178,11 @@ name = "num-traits"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "quick-error"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rand"
version = "0.5.6"
@ -127,11 +213,80 @@ name = "rawpointer" @@ -127,11 +213,80 @@ name = "rawpointer"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "redox_syscall"
version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "redox_termios"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"redox_syscall 0.1.51 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "regex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"aho-corasick 0.6.9 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
"thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"utf8-ranges 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "regex-syntax"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"ucd-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "termcolor"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"wincolor 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "termion"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)",
"redox_syscall 0.1.51 (registry+https://github.com/rust-lang/crates.io-index)",
"redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "thread_local"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"lazy_static 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "typenum"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ucd-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "utf8-ranges"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "winapi"
version = "0.3.6"
@ -146,29 +301,68 @@ name = "winapi-i686-pc-windows-gnu" @@ -146,29 +301,68 @@ name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "winapi-util"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wincolor"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi-util 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[metadata]
"checksum aho-corasick 0.6.9 (registry+https://github.com/rust-lang/crates.io-index)" = "1e9a933f4e58658d7b12defcf96dc5c720f20832deebe3e0a19efd3b6aaeeb9e"
"checksum alga 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "24bb00eeca59f2986c747b8c2f271d52310ce446be27428fc34705138b155778"
"checksum approx 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "3c57ff1a5b00753647aebbbcf4ea67fa1e711a65ea7a30eb90dbf07de2485aee"
"checksum atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "9a7d5b8723950951411ee34d271d99dddcc2035a16ab25310ea2c8cfd4369652"
"checksum bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "228047a76f468627ca71776ecdebd732a3423081fcf5125585bcd7c49886ce12"
"checksum cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "082bb9b28e00d3c9d39cc03e64ce4cea0f1bb9b3fde493f0cbc008472d22bdf4"
"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f"
"checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0"
"checksum env_logger 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "afb070faf94c85d17d50ca44f6ad076bce18ae92f0037d350947240a36e9d42e"
"checksum fuchsia-cprng 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "81f7f8eb465745ea9b02e2704612a9946a59fa40572086c6fd49d6ddcf30bf31"
"checksum generic-array 0.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8107dafa78c80c848b71b60133954b4a58609a3a1a5f9af037ecc7f67280f369"
"checksum humantime 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3ca7e5f2e110db35f93b837c81797f3714500b81d517bf20c431b16d3ca4f114"
"checksum itertools 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5b8467d9c1cebe26feb08c640139247fac215782d35371ade9a2136ed6085358"
"checksum lazy_static 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a374c89b9db55895453a74c1e38861d9deec0b01b405a82516e9d5de4820dea1"
"checksum libc 0.2.48 (registry+https://github.com/rust-lang/crates.io-index)" = "e962c7641008ac010fa60a7dfdc1712449f29c44ef2d4702394aea943ee75047"
"checksum libm 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "03c0bb6d5ce1b5cc6fd0578ec1cbc18c9d88b5b591a5c7c1d6c6175e266a0819"
"checksum log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c84ec4b527950aa83a329754b01dbe3f58361d1c5efacd1f6d68c494d08a17c6"
"checksum matrixmultiply 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "dcad67dcec2d58ff56f6292582377e6921afdf3bfbd533e26fb8900ae575e002"
"checksum memchr 2.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "e1dd4eaac298c32ce07eb6ed9242eda7d82955b9170b7d6db59b2e02cc63fcb8"
"checksum nalgebra 0.16.13 (registry+https://github.com/rust-lang/crates.io-index)" = "8e0799b53947b9c9048a1537f024f22f54701bbb75274f65955d081a87c0b739"
"checksum num-complex 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "107b9be86cd2481930688277b675b0114578227f034674726605b8a482d8baf8"
"checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1"
"checksum quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "9274b940887ce9addde99c4eee6b5c44cc494b182b97e73dc8ffdcb3397fd3f0"
"checksum rand 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c618c47cd3ebd209790115ab837de41425723956ad3ce2e6a7f09890947cacb9"
"checksum rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
"checksum rand_core 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d0e7a549d590831370895ab7ba4ea0c1b6b011d106b5ff2da6eee112615e6dc0"
"checksum rawpointer 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ebac11a9d2e11f2af219b8b8d833b76b1ea0e054aa0e8d8e9e4cbde353bdf019"
"checksum redox_syscall 0.1.51 (registry+https://github.com/rust-lang/crates.io-index)" = "423e376fffca3dfa06c9e9790a9ccd282fafb3cc6e6397d01dbf64f9bacc6b85"
"checksum redox_termios 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7e891cfe48e9100a70a3b6eb652fef28920c117d366339687bd5576160db0f76"
"checksum regex 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "37e7cbbd370869ce2e8dff25c7018702d10b21a20ef7135316f8daecd6c25b7f"
"checksum regex-syntax 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)" = "8c2f35eedad5295fdf00a63d7d4b238135723f92b434ec06774dad15c7ab0861"
"checksum termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4096add70612622289f2fdcdbd5086dc81c1e2675e6ae58d6c4f62a16c6d7f2f"
"checksum termion 1.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "689a3bdfaab439fd92bc87df5c4c78417d3cbe537487274e9b0b2dce76e92096"
"checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
"checksum typenum 1.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "612d636f949607bdf9b123b4a6f6d966dedf3ff669f7f045890d3a4a73948169"
"checksum ucd-util 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "535c204ee4d8434478593480b8f86ab45ec9aae0e83c568ca81abf0fd0e88f86"
"checksum utf8-ranges 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "796f7e48bef87609f7ade7e06495a87d5cd06c7866e6a5cbfceffc558a243737"
"checksum winapi 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "92c1eb33641e276cfa214a0522acad57be5c56b10cb348b3c5117db75f3ac4b0"
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
"checksum winapi-util 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7168bab6e1daee33b4557efd0e95d5ca70a03706d39fa5f3fe7a236f584b03c9"
"checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
"checksum wincolor 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "561ed901ae465d6185fa7864d63fbd5720d0ef718366c9a4dc83cf6170d7e9ba"

5
Cargo.toml

@ -6,4 +6,7 @@ edition = "2018" @@ -6,4 +6,7 @@ edition = "2018"
[dependencies]
nalgebra = "0.16"
approx = "0.3"
approx = "0.3"
itertools = "0.8"
log = "0.4"
env_logger = "0.6"

10
src/main.rs

@ -1,8 +1,16 @@ @@ -1,8 +1,16 @@
#![feature(drain_filter)]
#![feature(box_patterns)]
#![feature(slice_patterns)]
#![feature(bind_by_move_pattern_guards)]
#![feature(vec_remove_item)]
extern crate nalgebra;
#[macro_use]
extern crate approx;
extern crate itertools;
#[macro_use]
extern crate log;
extern crate env_logger;
mod entity;
mod math;
@ -13,6 +21,8 @@ fn main() { @@ -13,6 +21,8 @@ fn main() {
use math::Point2;
use relation::{Relation, ResolveResult};
env_logger::init();
println!("Hello, world!");
let origin = Point::new_ref(Var::new_single(Point2::new(0., 0.)));
let p1 = Point::new_ref(Var::new_full(Point2::new(1., 1.)));

532
src/math.rs

@ -214,9 +214,9 @@ impl Region2 { @@ -214,9 +214,9 @@ impl Region2 {
}
}
mod solve {
pub mod solve {
use std::collections::BTreeSet;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::iter::FromIterator;
@ -268,10 +268,179 @@ mod solve { @@ -268,10 +268,179 @@ mod solve {
enum Expr {
Unkn(Unknown),
Const(Scalar),
Plus(Box<Expr>, Box<Expr>),
Sum(Exprs),
Neg(Box<Expr>),
Times(Box<Expr>, Box<Expr>),
Inv(Box<Expr>),
Product(Exprs),
Div(Box<Expr>, Box<Expr>),
}
type Exprs = Vec<Expr>;
impl Unknowns for Exprs {
fn unknowns(&self) -> UnknownSet {
self.iter().flat_map(|e: &Expr| e.unknowns()).collect()
}
fn has_unknowns(&self) -> bool {
self.iter().any(|e: &Expr| e.has_unknowns())
}
fn has_unknown(&self, u: Unknown) -> bool {
self.iter().any(|e: &Expr| e.has_unknown(u))
}
}
fn write_separated_exprs(es: &Exprs, f: &mut fmt::Formatter, sep: &str) -> fmt::Result {
let mut is_first = true;
for e in es {
if is_first {
is_first = false;
} else {
write!(f, "{}", sep)?
}
write!(f, "({})", e)?
}
Ok(())
}
fn remove_common_terms(l: &mut Vec<Expr>, r: &mut Vec<Expr>) -> Vec<Expr> {
let common: Vec<_> = l.drain_filter(|e| r.contains(e)).collect();
common.iter().for_each(|e| {
r.remove_item(e);
});
common
}
fn remove_term(terms: &mut Vec<Expr>, term: &Expr) -> Option<Expr> {
terms.remove_item(term)
}
fn sum_fold(l: Expr, r: Expr) -> Expr {
use itertools::Itertools;
use Expr::*;
match (l, r) {
(Const(lc), Const(rc)) => Const(lc + rc),
(Const(c), o) | (o, Const(c)) if relative_eq!(c, 0.) => o,
(Product(mut l), Product(mut r)) => {
let comm = remove_common_terms(&mut l, &mut r);
Expr::new_product(Sum(comm), Expr::new_sum(Product(l), Product(r))).simplify()
}
(Product(mut l), r) | (r, Product(mut l)) => {
let comm = remove_term(&mut l, &r);
match comm {
Some(_) => {
Expr::new_product(r, Expr::new_sum(Product(l), Const(1.))).simplify()
}
None => Expr::new_sum(Product(l), r),
}
}
(l, r) => Expr::new_sum(l, r),
}
}
fn group_sum(es: Exprs) -> Exprs {
use Expr::*;
let mut common: BTreeMap<UnknownSet, Expr> = BTreeMap::new();
for e in es {
let unkns = e.unknowns();
match common.get_mut(&unkns) {
None => {
match e {
Const(c) if relative_eq!(c, 0.) => (),
e => {
common.insert(unkns, e);
}
};
}
Some(existing) => {
match existing {
Sum(ref mut es) => {
// already failed at merging, so just add it to the list
es.push(e);
}
other => {
*other = sum_fold(other.clone(), e);
}
};
}
};
}
common.into_iter().map(|(_, v)| v).collect()
}
fn product_fold(l: Expr, r: Expr) -> Expr {
use itertools::Itertools;
use Expr::*;
match (l, r) {
(Const(lc), Const(rc)) => Const(lc * rc),
(Const(c), o) | (o, Const(c)) if relative_eq!(c, 1.) => o,
(Div(num, den), mul) | (mul, Div(num, den)) => {
if mul == *den {
*num
} else {
Expr::Div(Box::new(Expr::Product(vec![*num, mul])), den).simplify()
}
}
(l, r) => Expr::new_product(l, r),
}
}
fn group_product(es: Exprs) -> Exprs {
use Expr::*;
let mut common: BTreeMap<UnknownSet, Expr> = BTreeMap::new();
for e in es {
let unkns = e.unknowns();
match common.get_mut(&unkns) {
None => {
match e {
Const(c) if relative_eq!(c, 1.) => (),
e => {
common.insert(unkns, e);
}
};
}
Some(existing) => {
match existing {
Sum(ref mut es) => {
// already failed at merging, so just add it to the list
es.push(e);
}
other => *other = product_fold(other.clone(), e),
};
}
};
}
common.into_iter().map(|(_, v)| v).collect()
}
fn distribute_product_sums(mut es: Exprs) -> Expr {
trace!("distribute_product_sums: {}", Product(es.clone()));
use itertools::Itertools;
use Expr::*;
let sums = es
.drain_filter(|e| match e {
Sum(_) => true,
_ => false,
})
.map(|e| {
trace!("sum in product: {}", e);
match e {
Sum(es) => es,
_ => unreachable!(),
}
});
let products: Vec<_> = sums.multi_cartesian_product().collect();
if products.is_empty() {
trace!("no sums to distribute");
return Product(es);
}
let sums = products
.into_iter()
.map(|mut prod| {
prod.extend(es.clone());
trace!("prod: {}", Product(prod.clone()));
Product(prod)
})
.collect();
Sum(sums)
}
impl Unknowns for Expr {
@ -280,8 +449,9 @@ mod solve { @@ -280,8 +449,9 @@ mod solve {
match self {
Unkn(u) => u.unknowns(),
Const(_) => UnknownSet::default(),
Plus(l, r) | Times(l, r) => l.unknowns().union(&r.unknowns()).cloned().collect(),
Neg(e) | Inv(e) => e.unknowns(),
Sum(es) | Product(es) => es.unknowns(),
Div(l, r) => l.unknowns().union(&r.unknowns()).cloned().collect(),
Neg(e) => e.unknowns(),
}
}
fn has_unknowns(&self) -> bool {
@ -289,8 +459,9 @@ mod solve { @@ -289,8 +459,9 @@ mod solve {
match self {
Unkn(u) => u.has_unknowns(),
Const(_) => false,
Plus(l, r) | Times(l, r) => l.has_unknowns() || r.has_unknowns(),
Neg(e) | Inv(e) => e.has_unknowns(),
Sum(es) | Product(es) => es.has_unknowns(),
Div(l, r) => l.has_unknowns() || r.has_unknowns(),
Neg(e) => e.has_unknowns(),
}
}
fn has_unknown(&self, u: Unknown) -> bool {
@ -298,30 +469,31 @@ mod solve { @@ -298,30 +469,31 @@ mod solve {
match self {
Unkn(u1) => u1.has_unknown(u),
Const(_) => false,
Plus(l, r) | Times(l, r) => l.has_unknown(u) || r.has_unknown(u),
Neg(e) | Inv(e) => e.has_unknown(u),
Sum(es) | Product(es) => es.has_unknown(u),
Div(l, r) => l.has_unknown(u) || r.has_unknown(u),
Neg(e) => e.has_unknown(u),
}
}
}
impl Expr {
fn new_plus(e1: Expr, e2: Expr) -> Expr {
Expr::Plus(Box::new(e1), Box::new(e2))
fn new_sum(e1: Expr, e2: Expr) -> Expr {
Expr::Sum(vec![e1, e2])
}
fn new_times(e1: Expr, e2: Expr) -> Expr {
Expr::Times(Box::new(e1), Box::new(e2))
fn new_product(e1: Expr, e2: Expr) -> Expr {
Expr::Product(vec![e1, e2])
}
fn new_neg(e1: Expr) -> Expr {
Expr::Neg(Box::new(e1))
}
fn new_inv(e1: Expr) -> Expr {
Expr::Inv(Box::new(e1))
fn new_div(num: Expr, den: Expr) -> Expr {
Expr::Div(Box::new(num), Box::new(den))
}
fn new_minus(e1: Expr, e2: Expr) -> Expr {
Expr::Plus(Box::new(e1), Box::new(Expr::new_neg(e2)))
Expr::Sum(vec![e1, Expr::new_neg(e2)])
}
fn new_div(e1: Expr, e2: Expr) -> Expr {
Expr::Times(Box::new(e1), Box::new(Expr::new_inv(e2)))
fn new_inv(den: Expr) -> Expr {
Expr::new_div(Expr::Const(1.), den)
}
fn is_zero(&self) -> bool {
@ -343,90 +515,132 @@ mod solve { @@ -343,90 +515,132 @@ mod solve {
fn simplify(self) -> Expr {
use Expr::*;
match self {
Plus(l, r) => match (l.simplify(), r.simplify()) {
(Const(lc), Const(rc)) => Const(lc + rc),
(Const(c), ref o) | (ref o, Const(c)) if relative_eq!(c, 0.) => o.clone(),
(Times(l1, l2), Times(r1, r2)) => {
if l2 == r2 {
Expr::new_times(Expr::Plus(l1, r1), *l2).simplify()
} else if l1 == r1 {
Expr::new_times(Expr::Plus(l2, r2), *l1).simplify()
} else if l1 == r2 {
Expr::new_times(Expr::Plus(l2, r1), *l1).simplify()
} else if l2 == r1 {
Expr::new_times(Expr::Plus(l1, r2), *l2).simplify()
} else {
Expr::new_plus(Times(l1, l2), Times(r1, r2))
Sum(es) => {
let mut new_es: Vec<_> = es
.into_iter()
.map(|e| e.simplify())
.flat_map(|e| match e {
Sum(more_es) => more_es,
other => vec![other],
})
.collect();
let pre_new_es = new_es.clone();
new_es = group_sum(new_es);
trace!(
"simplify sum {} => {}",
Sum(pre_new_es),
Sum(new_es.clone())
);
match new_es.len() {
0 => Const(0.), // none
1 => new_es.into_iter().next().unwrap(), // one
_ => Sum(new_es), // many
}
}
Product(es) => {
let new_es: Vec<_> = es
.into_iter()
.map(|e| e.simplify())
.flat_map(|e| match e {
Product(more_es) => more_es,
other => vec![other],
})
.collect();
let pre_new_es = new_es.clone();
let new_es = group_product(new_es);
trace!(
"simplify product {} => {}",
Product(pre_new_es),
Product(new_es.clone())
);
match new_es.len() {
0 => Const(1.), // none
1 => new_es.into_iter().next().unwrap(), // one
_ => Product(new_es), // many
}
}
Neg(mut v) => {
*v = v.simplify();
trace!("simplify neg {}", Neg(v.clone()));
match v {
box Const(c) => Const(-c),
box Neg(v) => *v,
e => Product(vec![Const(-1.), *e]),
}
}
Div(mut num, mut den) => {
*num = num.simplify();
*den = den.simplify();
trace!("simplify div {}", Div(num.clone(), den.clone()));
match (num, den) {
(box Const(num), box Const(den)) => Const(num / den),
(num, box Const(den)) => {
if relative_eq!(den, 1.) {
*num
} else {
Expr::new_product(*num, Const(1. / den))
}
}
(num, box Div(dennum, denden)) => {
Div(Box::new(Product(vec![*num, *denden])), dennum).simplify()
}
(box Product(mut es), box den) => match es.remove_item(&den) {
Some(_) => Product(es),
None => Expr::new_div(Product(es), den),
},
(num, den) => {
if num == den {
Expr::Const(1.)
} else {
Div(num, den)
}
}
}
(l, r) => Self::new_plus(l, r),
},
Times(l, r) => match (l.simplify(), r.simplify()) {
(Const(lc), Const(rc)) => Const(lc * rc),
(Const(c), ref o) | (ref o, Const(c)) if relative_eq!(c, 1.) => o.clone(),
(Inv(ref den), ref num) | (ref num, Inv(ref den)) if *num == **den => Const(1.),
(l, r) => Self::new_times(l, r),
},
Neg(v) => match v.simplify() {
Const(c) => Const(-c),
Neg(v) => *v,
e => Self::new_times(Const(-1.), e),
},
Inv(v) => match v.simplify() {
Const(c) => Const(1. / c),
Inv(v) => *v,
e => Self::new_inv(e),
},
e => e,
}
}
fn distrubute(self) -> Expr {
use Expr::*;
match self {
Plus(l, r) => Expr::new_plus(l.distrubute(), r.distrubute()),
Times(l, r) => match (*l, *r) {
(Plus(l, r), o) | (o, Plus(l, r)) => Expr::new_plus(
Expr::Times(Box::new(o.clone()), l),
Expr::Times(Box::new(o), r),
)
.distrubute(),
(l, r) => Expr::new_times(l, r),
},
Neg(v) => match *v {
Plus(l, r) => Expr::new_plus(Neg(l).distrubute(), Neg(r).distrubute()),
Times(l, r) => Expr::new_times(Neg(l).distrubute(), *r),
Neg(v) => v.distrubute(),
Inv(v) => Expr::new_inv(Neg(v).distrubute()),
e => Expr::new_neg(e),
},
Inv(v) => match *v {
Plus(l, r) => Expr::new_plus(Inv(l).distrubute(), Inv(r).distrubute()),
Times(l, r) => Expr::new_times(Inv(l).distrubute(), Inv(r).distrubute()),
Inv(v) => v.distrubute(),
e => Expr::new_inv(e),
},
}
e => e,
}
}
fn reduce(self, for_u: Unknown) -> Expr {
fn distribute(self) -> Expr {
use Expr::*;
trace!("distribute {}", self);
match self {
Plus(l, r) => match (l.reduce(for_u), r.reduce(for_u)) {
(Const(lc), Const(rc)) => Const(lc + rc),
(l, r) => Self::new_plus(l, r),
},
Times(l, r) => match (l.reduce(for_u), r.reduce(for_u)) {
(Const(lc), Const(rc)) => Const(lc * rc),
(l, r) => Self::new_times(l, r),
},
Neg(v) => match v.reduce(for_u) {
Unkn(u) if u == for_u => Expr::new_times(Const(-1.), Unkn(u)),
e => Self::new_neg(e),
},
Inv(v) => match v.reduce(for_u) {
e => Self::new_inv(e),
Sum(mut es) => {
for e in &mut es {
*e = e.clone().distribute();
}
Sum(es)
}
Product(es) => distribute_product_sums(es),
Div(mut num, mut den) => {
*num = num.distribute();
*den = den.distribute();
match (num, den) {
(box Sum(es), box den) => Sum(es
.into_iter()
.map(|e| Expr::new_div(e, den.clone()))
.collect()),
(mut num, mut den) => Div(num, den),
}
}
Neg(v) => match v {
// box Sum(mut l, mut r) => {
// *l = Neg(l.clone()).distribute();
// *r = Neg(r.clone()).distribute();
// Sum(l, r)
// }
// box Product(mut l, r) => {
// *l = Neg(l.clone()).distribute();
// Product(l, r)
// }
box Neg(v) => v.distribute(),
box Div(mut num, mut den) => {
*num = Neg(num.clone()).distribute();
*den = Neg(den.clone()).distribute();
Div(num, den)
}
e => Neg(e),
},
e => e,
}
@ -439,10 +653,10 @@ mod solve { @@ -439,10 +653,10 @@ mod solve {
match self {
Unkn(u) => write!(f, "{}", u),
Const(c) => write!(f, "{}", c),
Plus(l, r) => write!(f, "({}) + ({})", l, r),
Times(l, r) => write!(f, "({}) * ({})", l, r),
Sum(es) => write_separated_exprs(es, f, " + "),
Product(es) => write_separated_exprs(es, f, " * "),
Div(num, den) => write!(f, "({}) / ({})", num, den),
Neg(e) => write!(f, "-({})", e),
Inv(e) => write!(f, "1 / ({})", e),
}
}
}
@ -477,27 +691,62 @@ mod solve { @@ -477,27 +691,62 @@ mod solve {
}
impl Eqn {
fn simplify(self) -> Eqn {
Eqn(self.0.simplify(), self.1.simplify())
}
fn solve(&self, for_u: Unknown) -> Option<Expr> {
use Expr::*;
if !self.has_unknown(for_u) {
return None;
}
let (l, r) = (self.0.clone().simplify(), self.1.clone().simplify());
let (l, r) = (
self.0
.clone() /*.distribute()*/
.simplify(),
self.1
.clone() /*.distribute()*/
.simplify(),
);
let (mut l, mut r) = ord_by_unkn(l, r, for_u)?;
loop {
trace!("solve: {} == {}", l, r);
let (new_l, new_r): (Expr, Expr) = match l {
Unkn(u) => return if u == for_u { Some(r.simplify()) } else { None },
Plus(a, b) => {
let (a, b) = ord_by_unkn(*a, *b, for_u)?;
(a, Expr::new_minus(r, b))
Sum(es) => {
let (us, not_us): (Vec<_>, Vec<_>) =
es.into_iter().partition(|e| e.has_unknown(for_u));
if us.len() != 1 {
return None;
}
(
Sum(us).simplify(),
Expr::new_minus(r, Sum(not_us)).simplify(),
)
}
Times(a, b) => {
let (a, b) = ord_by_unkn(*a, *b, for_u)?;
(a, Expr::new_div(r, b))
Product(es) => {
let (us, not_us): (Vec<_>, Vec<_>) =
es.into_iter().partition(|e| e.has_unknown(for_u));
if us.len() != 1 {
return None;
}
(
Product(us).simplify(),
Expr::new_div(r, Product(not_us)).simplify(),
)
}
Neg(v) => (*v, Expr::new_neg(r)),
Inv(v) => (*v, Expr::new_inv(r)),
Div(num, den) => {
let (nu, du) = (num.has_unknown(for_u), den.has_unknown(for_u));
match (nu, du) {
(true, false) => (*num, Expr::new_product(r, *den)),
(false, true) => (Expr::new_product(r, *den), *num),
(true, true) => return None, // TODO: simplify
(false, false) => return None,
}
}
Const(_) => return None,
_ => return None,
};
l = new_l;
r = new_r;
@ -505,6 +754,12 @@ mod solve { @@ -505,6 +754,12 @@ mod solve {
}
}
impl fmt::Display for Eqn {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} == {}", self.0, self.1)
}
}
#[derive(Clone, Debug, PartialEq)]
struct Eqns(Vec<Eqn>);
@ -554,6 +809,7 @@ mod solve { @@ -554,6 +809,7 @@ mod solve {
#[test]
fn test_solve() {
use Expr::*;
let _ = env_logger::try_init();
let u1 = Unknown(1);
let e1 = Unkn(u1);
let e2 = Const(1.);
@ -563,9 +819,9 @@ mod solve { @@ -563,9 +819,9 @@ mod solve {
let eqn = Eqn(e2.clone(), e1.clone());
assert_eq!(eqn.solve(u1), Some(Const(1.)));
let e3 = Expr::new_plus(Const(1.), Const(1.));
let e3 = Expr::new_sum(Const(1.), Expr::new_sum(Const(1.), Const(2.)));
let eqn = Eqn(e1.clone(), e3.clone());
assert_eq!(eqn.solve(u1), Some(Const(2.)));
assert_eq!(eqn.solve(u1), Some(Const(4.)));
let e3 = Expr::new_minus(Const(1.), Const(1.));
let eqn = Eqn(e1.clone(), e3.clone());
assert_eq!(eqn.solve(u1), Some(Const(0.)));
@ -573,30 +829,74 @@ mod solve { @@ -573,30 +829,74 @@ mod solve {
let e1 = Expr::new_div(Const(2.), Expr::new_minus(Const(1.), Const(4.)));
let e2 = Expr::new_minus(Const(1.), Unkn(u1));
let eqn = Eqn(e1, e2);
info!("eqn: {} => {}", eqn, eqn.clone().simplify());
let e = eqn.solve(u1).unwrap();
assert!(const_expr(e.clone()).is_some());
assert!(relative_eq!(const_expr(e.clone()).unwrap(), 5. / 3.));
let e1 = Expr::new_times(Const(2.), Expr::new_minus(Const(1.), Const(4.)));
let e2 = Expr::new_minus(Expr::new_times(Unkn(u1), Const(2.)), Unkn(u1));
println!(
let e1 = Expr::new_product(Const(2.), Expr::new_minus(Const(1.), Const(4.)));
let e2 = Expr::new_minus(Expr::new_product(Unkn(u1), Const(2.)), Unkn(u1));
info!(
"e1==e2: {}=={} => {}=={}",
e1,
e2,
e1.clone().simplify(),
e2.clone().simplify()
);
println!(
let eqn = Eqn(e1, e2);
let e = eqn.solve(u1).unwrap();
assert!(const_expr(e.clone()).is_some());
assert!(relative_eq!(const_expr(e.clone()).unwrap(), -6.));
let e1 = Expr::new_product(Const(2.), Expr::new_minus(Const(1.), Const(4.)));
let e2 = Expr::new_div(
Expr::new_sum(
Expr::new_product(Unkn(u1), Const(2.)),
Expr::new_product(Unkn(u1), Unkn(u1)),
),
Unkn(u1),
);
info!(
"{}=={} distrib=> {}=={}",
e1,
e2,
e1.clone().distribute(),
e2.clone().distribute()
);
info!(
"{}=={} simplify=> {}=={}",
e1,
e2,
e1.clone().distribute().simplify(),
e2.clone().distribute().simplify()
);
let eqn = Eqn(e1, e2);
let e = eqn.solve(u1).unwrap();
assert!(const_expr(e.clone()).is_some());
assert!(relative_eq!(const_expr(e.clone()).unwrap(), -8.));
let e1 = Expr::new_product(Const(2.), Expr::new_minus(Const(1.), Const(4.)));
let e2 = Expr::new_div(
Expr::new_sum(
Expr::new_product(Unkn(u1), Const(2.)),
Expr::new_sum(
Expr::new_sum(Expr::new_product(Unkn(u1), Unkn(u1)), Unkn(u1)),
Expr::new_minus(Const(2.), Const(1. + 1.)),
),
),
Unkn(u1),
);
info!(
"e1==e2: {}=={} => {}=={}",
e1,
e2,
e1.clone().distrubute(),
e2.clone().distrubute()
e1.clone().distribute().simplify(),
e2.clone().distribute().simplify().simplify()
);
let eqn = Eqn(e1, e2);
let e = eqn.solve(u1).unwrap();
assert!(const_expr(e.clone()).is_some());
assert!(relative_eq!(const_expr(e.clone()).unwrap(), -6.));
assert!(relative_eq!(const_expr(e.clone()).unwrap(), -9.));
}
}

Loading…
Cancel
Save