summaryrefslogblamecommitdiff
path: root/src/requests/start.rs
blob: bc6330389aa0f167ca3df21d3eecf14c6208e214 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
                                       
                                            
                        
                       
                                                             
 



                                                                                                   
 
                                                                                     
                                          
                                                                                  
                                                   
                                                            
                                                                                    
                                    
                                                                 
               
              
                                     
                                                                          
 
                                                                                      

                                                             
                                                              


                                       
                                                                               

                     
                                                                                                 
                           



                                                                    
                        



                                                                    
                 
             
         
                                                             
                                                              
                                       
 
                                                                                 
         



          





                             
                                                                              


                                                                          
                                                                       
 
                                                               
 


                                                                                                  
 
                                                                        
                                              
                                                                  
 
                                         
 




                                                                            
         












                                                                           

     
                                                                

                                              

                                                                 


                                                                                 



                                       
                                                             
 


                             
                 
 

                               
              
















                                                     
use futures_util::{SinkExt, StreamExt};
use indicatif::{MultiProgress, ProgressBar};
use reqwest::StatusCode;
use serde::Deserialize;
use tokio_tungstenite::{connect_async, tungstenite::Message};

use crate::{
    add_pb, config::Config, default_headers, error::CliError, finish_pb, format_url, ErrorResponse,
    Protocols, DEFAULT_STYLE, DONE_STYLE, ERROR_STYLE, OVERVIEW_DONE, OVERVIEW_ERROR,
    OVERVIEW_STYLE,
};

pub async fn start(config: &Config, id: String, ping: bool) -> Result<(), CliError> {
    let send_start = MultiProgress::new();
    let overview = add_pb(&send_start, OVERVIEW_STYLE, format!(") start {}", id));

    // TODO: calculate average start-time on server
    let url = format_url(config, "start", Protocols::Http)?;
    let connect = add_pb(&send_start, DEFAULT_STYLE, format!("connect to {}", url));
    let res = reqwest::Client::new()
        .post(url)
        .headers(default_headers(config)?)
        .body(format!(r#"{{"id": "{}", "ping": {}}}"#, id, ping))
        .send()
        .await
        .map_err(CliError::Reqwest)?;
    finish_pb(connect, "connected, got response".to_string(), DONE_STYLE);

    let res_pb = add_pb(&send_start, DEFAULT_STYLE, "analyzing response".to_string());
    match res.status() {
        StatusCode::OK => {
            let body = serde_json::from_str::<StartResponse>(
                &res.text().await.map_err(CliError::Reqwest)?,
            )
            .map_err(CliError::Serde)?;

            if body.boot {
                finish_pb(res_pb, "sent start packet".to_string(), DONE_STYLE);
            }

            if ping {
                let status = status_socket(config, body.uuid, &send_start, &overview, id).await?;
                if status {
                    finish_pb(
                        overview,
                        format!("successfully started {}", body.id),
                        OVERVIEW_DONE,
                    );
                } else {
                    finish_pb(
                        overview,
                        format!("error while starting {}", body.id),
                        OVERVIEW_ERROR,
                    );
                }
            }
        }
        _ => {
            let body = serde_json::from_str::<ErrorResponse>(
                &res.text().await.map_err(CliError::Reqwest)?,
            )
            .map_err(CliError::Serde)?;

            res_pb.finish_with_message(format!("✗ got error: {}", body.error));
        }
    }

    Ok(())
}

async fn status_socket(
    config: &Config,
    uuid: String,
    pb: &MultiProgress,
    overview: &ProgressBar,
    id: String,
) -> Result<bool, CliError> {
    let ws_pb = add_pb(pb, DEFAULT_STYLE, "connect to websocket".to_string());
    let (mut ws_stream, _response) =
        connect_async(format_url(config, "status", Protocols::Websocket)?)
            .await
            .expect("Failed to connect");
    finish_pb(ws_pb, "connected to websocket".to_string(), DONE_STYLE);

    ws_stream.send(Message::Text(uuid.clone())).await.unwrap();

    // Get ETA
    let eta_msg = ws_stream.next().await.unwrap().unwrap();
    let eta = get_eta(eta_msg.into_text().unwrap(), uuid.clone())? + overview.elapsed().as_secs();
    overview.set_message(format!("/{}) start {}", eta, id));

    let msg_pb = add_pb(pb, DEFAULT_STYLE, "await message".to_string());
    let msg = ws_stream.next().await.unwrap();
    finish_pb(msg_pb, "received message".to_string(), DONE_STYLE);

    ws_stream.close(None).await.unwrap();

    let v_pb = add_pb(pb, DEFAULT_STYLE, "verify response".to_string());
    let res = verify_response(msg.unwrap().to_string(), uuid)?;
    match res {
        Verified::WrongUuid => {
            finish_pb(v_pb, "returned wrong uuid".to_string(), ERROR_STYLE);
            Ok(false)
        }
        Verified::ResponseType(res_type) => match res_type {
            ResponseType::Start => {
                finish_pb(v_pb, "device started".to_string(), DONE_STYLE);
                Ok(true)
            }
            ResponseType::Timeout => {
                finish_pb(v_pb, "ping timed out".to_string(), ERROR_STYLE);
                Ok(false)
            }
            ResponseType::NotFound => {
                finish_pb(v_pb, "unknown uuid".to_string(), ERROR_STYLE);
                Ok(false)
            }
        },
    }
}

/// Extracts the ETA value from a websocket message of the form `eta_<number>_<uuid>`.
///
/// Segments after the third one are ignored.
///
/// # Errors
///
/// * `CliError::WsResponse` if the message has fewer than three `_`-separated segments, does
///   not start with `eta`, or carries a uuid different from `uuid`.
/// * `CliError::Parse` if the second segment is not a valid base-10 `u64`.
fn get_eta(msg: String, uuid: String) -> Result<u64, CliError> {
    let mut parts = msg.split('_');
    // Pull the three expected segments without indexing, so that a truncated message is
    // reported as an invalid response instead of panicking.
    match (parts.next(), parts.next(), parts.next()) {
        (Some("eta"), Some(secs), Some(msg_uuid)) if msg_uuid == uuid => {
            secs.parse::<u64>().map_err(CliError::Parse)
        }
        _ => Err(CliError::WsResponse),
    }
}

fn verify_response(res: String, org_uuid: String) -> Result<Verified, CliError> {
    let spl: Vec<&str> = res.split('_').collect();
    let res_type = spl[0];
    let uuid = spl[1];

    if uuid != org_uuid {
        return Ok(Verified::WrongUuid);
    };

    Ok(Verified::ResponseType(ResponseType::from(res_type)?))
}

/// JSON body deserialized from a `200 OK` reply to the start request.
#[derive(Debug, Deserialize)]
struct StartResponse {
    /// When `true`, `start` reports "sent start packet" on the progress bar.
    boot: bool,
    /// Device identifier echoed by the server; used in the final overview message.
    id: String,
    /// Identifier sent over the status websocket and matched against its replies.
    uuid: String,
}

/// Outcome of checking a websocket status message against the expected uuid.
enum Verified {
    /// The uuid matched; carries the parsed response type.
    ResponseType(ResponseType),
    /// The message carried a uuid different from the expected one.
    WrongUuid,
}

/// Kind of status message received over the websocket.
enum ResponseType {
    /// Parsed from `"start"`; reported as "device started".
    Start,
    /// Parsed from `"timeout"`; reported as "ping timed out".
    Timeout,
    /// Parsed from `"notfound"`; reported as "unknown uuid".
    NotFound,
}

impl ResponseType {
    /// Converts the type tag of a websocket status message into a `ResponseType`.
    ///
    /// Recognized tags are `"start"`, `"timeout"` and `"notfound"` (exact, case-sensitive
    /// matches); any other tag yields `CliError::WsResponse`.
    fn from(value: &str) -> Result<Self, CliError> {
        let kind = match value {
            "start" => Self::Start,
            "timeout" => Self::Timeout,
            "notfound" => Self::NotFound,
            // Unknown tag: bail out early with the websocket-response error.
            _ => return Err(CliError::WsResponse),
        };
        Ok(kind)
    }
}