Browse Source

Implement cancelling by section from MQTT

master
Alex Mikhalev 4 years ago
parent
commit
b72d89e1dd
  1. 74
      src/mqtt/request/mod.rs
  2. 43
      src/mqtt/request/run_section.rs
  3. 87
      src/mqtt/request/sections.rs
  4. 89
      src/section_runner.rs

74
src/mqtt/request/mod.rs

@ -1,11 +1,12 @@ @@ -1,11 +1,12 @@
use crate::{model::Sections, section_runner::SectionRunner};
use futures_util::ready;
use futures_util::FutureExt;
use num_derive::FromPrimitive;
use serde::{Deserialize, Serialize};
use std::{fmt, future::Future, pin::Pin};
use std::{fmt, future::Future, pin::Pin, task::Poll};
mod run_section;
mod sections;
pub struct RequestContext {
pub sections: Sections,
@ -14,8 +15,6 @@ pub struct RequestContext { @@ -14,8 +15,6 @@ pub struct RequestContext {
type BoxFuture<Output> = Pin<Box<dyn Future<Output = Output>>>;
pub type ResponseValue = serde_json::Value;
#[derive(Copy, Clone, Debug, PartialEq, Eq, FromPrimitive)]
#[repr(u16)]
pub enum ErrorCode {
@ -156,17 +155,69 @@ impl RequestError { @@ -156,17 +155,69 @@ impl RequestError {
}
}
type RequestResult = Result<ResponseValue, RequestError>;
type RequestFuture = BoxFuture<RequestResult>;
#[derive(Debug, Serialize, Deserialize)]
struct ResponseMessage(#[serde(rename = "message")] String);
impl ResponseMessage {
fn new<M>(message: M) -> Self
where
M: ToString,
{
ResponseMessage(message.to_string())
}
}
impl From<String> for ResponseMessage {
fn from(message: String) -> Self {
ResponseMessage(message)
}
}
pub type ResponseValue = serde_json::Value;
type RequestResult<Ok = ResponseValue> = Result<Ok, RequestError>;
type RequestFuture<Ok = ResponseValue> = BoxFuture<RequestResult<Ok>>;
trait IRequest {
fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture;
type Response: Serialize;
fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture<Self::Response>;
fn exec_erased(&mut self, ctx: &mut RequestContext) -> RequestFuture
where
Self::Response: 'static,
{
// TODO: figure out how to get rid of this nested box
Box::pin(ErasedRequestFuture(self.exec(ctx)))
}
}
struct ErasedRequestFuture<Response>(RequestFuture<Response>)
where
Response: Serialize;
impl<Response> Future for ErasedRequestFuture<Response>
where
Response: Serialize,
{
type Output = RequestResult;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
use eyre::WrapErr;
let response = ready!(self.as_mut().0.poll_unpin(cx));
Poll::Ready(response.and_then(|res| {
serde_json::to_value(res)
.wrap_err("could not serialize response")
.map_err(RequestError::from)
}))
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum Request {
RunSection(run_section::RequestData),
RunSection(sections::RunSectionRequest),
CancelSection(sections::CancelSectionRequest),
}
#[derive(Debug, Deserialize, Serialize)]
@ -192,9 +243,12 @@ impl From<RequestError> for Response { @@ -192,9 +243,12 @@ impl From<RequestError> for Response {
}
impl IRequest for Request {
fn exec(&mut self, ctx: &mut RequestContext) -> BoxFuture<RequestResult> {
type Response = ResponseValue;
fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture {
match self {
Request::RunSection(req) => req.exec(ctx),
Request::RunSection(req) => req.exec_erased(ctx),
Request::CancelSection(req) => req.exec_erased(ctx),
}
}
}

43
src/mqtt/request/run_section.rs

@ -1,43 +0,0 @@ @@ -1,43 +0,0 @@
use super::*;
use crate::{model::SectionId, section_runner::SectionRunHandle};
use eyre::WrapErr;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestData {
pub section_id: SectionId,
#[serde(with = "crate::serde::duration")]
pub duration: Duration,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResponseData {
pub message: String,
pub run_id: SectionRunHandle,
}
impl IRequest for RequestData {
fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture {
let mut section_runner = ctx.section_runner.clone();
let section = ctx.sections.get(&self.section_id).cloned();
let duration = self.duration;
Box::pin(async move {
let section = section.ok_or_else(|| {
RequestError::with_name(ErrorCode::NotFound, "section not found", "section")
})?;
let handle = section_runner
.queue_run(section.clone(), duration)
.await
.wrap_err("could not queue run")?;
let res = ResponseData {
message: format!("running section '{}' for {:?}", &section.name, duration),
run_id: handle,
};
let res_value = serde_json::to_value(res).wrap_err("could not serialize response")?;
Ok(res_value)
})
}
}

87
src/mqtt/request/sections.rs

@ -0,0 +1,87 @@ @@ -0,0 +1,87 @@
use super::*;
use crate::{model::SectionRef, section_runner::SectionRunHandle};
use eyre::WrapErr;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct SectionId(pub crate::model::SectionId);
impl SectionId {
fn get_section(self, sections: &Sections) -> Result<SectionRef, RequestError> {
sections.get(&self.0).cloned().ok_or_else(|| {
RequestError::with_name(ErrorCode::NotFound, "section not found", "section")
})
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RunSectionRequest {
pub section_id: SectionId,
#[serde(with = "crate::serde::duration")]
pub duration: Duration,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RunSectionResponse {
pub message: String,
pub run_id: SectionRunHandle,
}
impl IRequest for RunSectionRequest {
type Response = RunSectionResponse;
fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture<Self::Response> {
let mut section_runner = ctx.section_runner.clone();
let section = self.section_id.get_section(&ctx.sections);
let duration = self.duration;
Box::pin(async move {
let section = section?;
let handle = section_runner
.queue_run(section.clone(), duration)
.await
.wrap_err("could not queue run")?;
Ok(RunSectionResponse {
message: format!("running section '{}' for {:?}", &section.name, duration),
run_id: handle,
})
})
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CancelSectionRequest {
pub section_id: SectionId,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CancelSectionResponse {
pub message: String,
pub cancelled: usize,
}
impl IRequest for CancelSectionRequest {
type Response = CancelSectionResponse;
fn exec(&mut self, ctx: &mut RequestContext) -> RequestFuture<Self::Response> {
let mut section_runner = ctx.section_runner.clone();
let section = self.section_id.get_section(&ctx.sections);
Box::pin(async move {
let section = section?;
let cancelled = section_runner
.cancel_by_section(section.id)
.await
.wrap_err("could not cancel section")?;
Ok(CancelSectionResponse {
message: format!(
"cancelled {} runs for section '{}'",
cancelled, section.name
),
cancelled,
})
})
}
}

89
src/section_runner.rs

@ -1,4 +1,4 @@ @@ -1,4 +1,4 @@
use crate::model::SectionRef;
use crate::model::{SectionId, SectionRef};
use crate::section_interface::SectionInterface;
use actix::{
Actor, ActorContext, Addr, AsyncContext, Handler, Message, MessageResult, SpawnHandle,
@ -20,8 +20,7 @@ use tokio::{ @@ -20,8 +20,7 @@ use tokio::{
};
use tracing::{debug, trace, warn};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Deserialize, serde::Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
pub struct SectionRunHandle(i32);
impl SectionRunHandle {
@ -177,19 +176,24 @@ impl SectionRunnerInner { @@ -177,19 +176,24 @@ impl SectionRunnerInner {
}
}
fn cancel_run(&mut self, run: &mut Arc<SecRun>) {
fn cancel_run(&mut self, run: &mut Arc<SecRun>) -> bool {
let run = Arc::make_mut(run);
if run.is_running() {
debug!(section_id = run.section.id, "cancelling running section");
self.interface
.set_section_state(run.section.interface_id, false);
}
if run.state != SecRunState::Cancelled {
run.state = SecRunState::Cancelled;
self.send_event(SectionEvent::RunCancel(
run.handle.clone(),
run.section.clone(),
));
self.did_change = true;
true
} else {
false
}
}
fn pause_run(&mut self, run: &mut Arc<SecRun>) {
@ -322,40 +326,78 @@ impl Handler<QueueRun> for SectionRunnerActor { @@ -322,40 +326,78 @@ impl Handler<QueueRun> for SectionRunnerActor {
}
#[derive(Message, Debug, Clone)]
#[rtype(result = "()")]
#[rtype(result = "bool")]
struct CancelRun(SectionRunHandle);
impl Handler<CancelRun> for SectionRunnerActor {
type Result = ();
type Result = bool;
fn handle(&mut self, msg: CancelRun, ctx: &mut Self::Context) -> Self::Result {
let CancelRun(handle) = msg;
for run in self.state.run_queue.iter_mut() {
if run.handle != handle {
continue;
}
let mut cancelled = false;
for run in self
.state
.run_queue
.iter_mut()
.filter(|run| run.handle == handle)
{
trace!(handle = handle.0, "cancelling run by handle");
self.inner.cancel_run(run);
cancelled = self.inner.cancel_run(run);
}
ctx.notify(Process);
cancelled
}
}
#[derive(Message, Debug, Clone)]
#[rtype(result = "()")]
#[rtype(result = "usize")]
struct CancelBySection(SectionId);
impl Handler<CancelBySection> for SectionRunnerActor {
type Result = usize;
fn handle(&mut self, msg: CancelBySection, ctx: &mut Self::Context) -> Self::Result {
let CancelBySection(section_id) = msg;
let mut count = 0_usize;
for run in self
.state
.run_queue
.iter_mut()
.filter(|run| run.section.id == section_id)
{
trace!(
handle = run.handle.0,
section_id,
"cancelling run by section"
);
if self.inner.cancel_run(run) {
count += 1;
}
}
ctx.notify(Process);
count
}
}
#[derive(Message, Debug, Clone)]
#[rtype(result = "usize")]
struct CancelAll;
impl Handler<CancelAll> for SectionRunnerActor {
type Result = ();
type Result = usize;
fn handle(&mut self, _msg: CancelAll, ctx: &mut Self::Context) -> Self::Result {
let mut old_runs = SecRunQueue::new();
swap(&mut old_runs, &mut self.state.run_queue);
trace!(count = old_runs.len(), "cancelling all runs");
let mut count = 0usize;
for mut run in old_runs {
self.inner.cancel_run(&mut run);
if self.inner.cancel_run(&mut run) {
count += 1;
}
}
ctx.notify(Process);
count
}
}
@ -517,11 +559,7 @@ impl SectionRunner { @@ -517,11 +559,7 @@ impl SectionRunner {
(QueueRun(handle.clone(), section, duration), handle)
}
pub fn do_queue_run(
&mut self,
section: SectionRef,
duration: Duration,
) -> SectionRunHandle {
pub fn do_queue_run(&mut self, section: SectionRef, duration: Duration) -> SectionRunHandle {
let (queue_run, handle) = self.queue_run_inner(section, duration);
self.addr.do_send(queue_run);
handle
@ -543,11 +581,20 @@ impl SectionRunner { @@ -543,11 +581,20 @@ impl SectionRunner {
self.addr.do_send(CancelRun(handle))
}
pub fn cancel_run(&mut self, handle: SectionRunHandle) -> impl Future<Output = Result<()>> {
pub fn cancel_run(&mut self, handle: SectionRunHandle) -> impl Future<Output = Result<bool>> {
self.addr.send(CancelRun(handle)).map_err(From::from)
}
pub fn cancel_all(&mut self) -> impl Future<Output = Result<()>> {
pub fn cancel_by_section(
&mut self,
section_id: SectionId,
) -> impl Future<Output = Result<usize>> {
self.addr
.send(CancelBySection(section_id))
.map_err(From::from)
}
pub fn cancel_all(&mut self) -> impl Future<Output = Result<usize>> {
self.addr.send(CancelAll).map_err(From::from)
}

Loading…
Cancel
Save