Advertisement
BilakshanP

main.rs - 2

Oct 16th, 2024
143
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 14.42 KB | Source Code | 0 0
  1. use aes_gcm::aead::{Aead, NewAead};
  2. use aes_gcm::{Aes256Gcm, Key, Nonce};
  3. use rand::Rng;
  4. use rand_core::OsRng;
  5. use std::fs;
  6. use std::hash::{DefaultHasher, Hash, Hasher};
  7. use std::io::{self, Read, Write};
  8. use std::net::{SocketAddr, TcpListener, TcpStream};
  9. use std::path::Path;
  10. use std::sync::{Arc, Mutex};
  11. use std::thread;
  12. use std::time::{SystemTime, UNIX_EPOCH};
  13. use x25519_dalek::{EphemeralSecret, PublicKey};
  14.  
  15. mod config;
  16.  
  17. use config::CONFIG;
  18.  
  19. enum Mode {
  20.     Server,
  21.     Client,
  22. }
  23.  
  24. enum MessageType {
  25.     Text(String),
  26.     File { name: String, content: Vec<u8> },
  27. }
  28.  
  29. struct ClientInfo {
  30.     stream: TcpStream,
  31.     id: String,
  32.     public_key: PublicKey,
  33.     shared_secret: Option<[u8; 32]>,
  34. }
  35.  
  36. fn generate_id() -> String {
  37.     let now = SystemTime::now()
  38.         .duration_since(UNIX_EPOCH)
  39.         .unwrap()
  40.         .as_millis();
  41.  
  42.     let mut hasher = DefaultHasher::new();
  43.  
  44.     now.hash(&mut hasher);
  45.  
  46.     format!("{:04x}", hasher.finish() % 65536)
  47. }
  48.  
  49. fn save_key(key: &[u8], prefix: &str, id: &str) {
  50.     if CONFIG.debug {
  51.         let filename = format!("{}-key#{}.pem", prefix, id);
  52.         fs::write(&filename, key).unwrap();
  53.         println!("Debug: Saved {} to {}", prefix, filename);
  54.     }
  55. }
  56.  
  57. fn save_file(content: &[u8], filename: &str) {
  58.     if CONFIG.debug {
  59.         fs::write(filename, content).unwrap();
  60.         println!("Debug: Saved unencrypted file to {}", filename);
  61.     }
  62. }
  63.  
  64. // fn debug_save_ephemeral_secret(secret: &EphemeralSecret, prefix: &str, id: &str) {
  65. //     if CONFIG.debug {
  66. //         let public_key = PublicKey::from(secret);
  67. //         let filename = format!("{}-key#{}.pem", prefix, id);
  68. //         fs::write(&filename, public_key.as_bytes()).unwrap();
  69. //         println!("Debug: Saved {} public key to {}", prefix, filename);
  70. //     }
  71. // }
  72.  
  73. fn debug_print(message: &str, data: &[u8]) {
  74.     if CONFIG.debug {
  75.         println!("Debug: {} - {:?}", message, data);
  76.         let mut hasher = DefaultHasher::new();
  77.         data.hash(&mut hasher);
  78.         println!("Debug: {} hash - {:x}", message, hasher.finish());
  79.     }
  80. }
  81.  
  82. fn encrypt_message(message: &str, shared_secret: &[u8; 32]) -> Vec<u8> {
  83.     let key = Key::from_slice(shared_secret);
  84.     let cipher = Aes256Gcm::new(key);
  85.     let pre_nonce = rand::thread_rng().gen::<[u8; 12]>();
  86.     let nonce = Nonce::from_slice(&pre_nonce);
  87.     let ciphertext = cipher.encrypt(nonce, message.as_bytes()).unwrap();
  88.  
  89.     debug_print("Unencrypted message", message.as_bytes());
  90.     debug_print("Encrypted message", &ciphertext);
  91.     [nonce.to_vec(), ciphertext].concat()
  92. }
  93.  
  94. fn decrypt_message(encrypted: &[u8], shared_secret: &[u8; 32]) -> String {
  95.     let key = Key::from_slice(shared_secret);
  96.     let cipher = Aes256Gcm::new(key);
  97.     let nonce = Nonce::from_slice(&encrypted[..12]);
  98.     let ciphertext = &encrypted[12..];
  99.     let plaintext = cipher.decrypt(nonce, ciphertext).unwrap();
  100.     debug_print("Encrypted message", encrypted);
  101.     debug_print("Decrypted message", &plaintext);
  102.     String::from_utf8(plaintext).unwrap()
  103. }
  104.  
  105. fn handle_connection(
  106.     mut client_info: ClientInfo,
  107.     clients: Arc<Mutex<Vec<ClientInfo>>>,
  108.     server_addr: SocketAddr,
  109. ) {
  110.     let mut buffer = [0; 1024 * 1024];
  111.     let client_addr = client_info.stream.peer_addr().unwrap();
  112.     let client_id = client_info.id.clone();
  113.  
  114.     if CONFIG.debug {
  115.         save_key(
  116.             client_info.public_key.as_bytes(),
  117.             "client-public",
  118.             &client_id,
  119.         );
  120.         if let Some(secret) = &client_info.shared_secret {
  121.             save_key(secret, "shared-secret", &client_id);
  122.         }
  123.     }
  124.  
  125.     loop {
  126.         match client_info.stream.read(&mut buffer) {
  127.             Ok(0) => break,
  128.             Ok(bytes_read) => {
  129.                 let encrypted_message = &buffer[..bytes_read];
  130.                 if let Some(shared_secret) = client_info.shared_secret {
  131.                     let decrypted_message = decrypt_message(encrypted_message, &shared_secret);
  132.  
  133.                     let message_type = if decrypted_message.starts_with("FILE:") {
  134.                         let parts: Vec<&str> = decrypted_message
  135.                             .strip_prefix("FILE:")
  136.                             .unwrap()
  137.                             .splitn(2, ':')
  138.                             .collect();
  139.                         if parts.len() == 2 {
  140.                             if CONFIG.debug {
  141.                                 save_file(parts[1].as_bytes(), &format!("debug_{}", parts[0]));
  142.                             }
  143.                             MessageType::File {
  144.                                 name: parts[0].to_string(),
  145.                                 content: parts[1].as_bytes().to_vec(),
  146.                             }
  147.                         } else {
  148.                             MessageType::Text(decrypted_message)
  149.                         }
  150.                     } else {
  151.                         MessageType::Text(decrypted_message)
  152.                     };
  153.  
  154.                     match &message_type {
  155.                         MessageType::Text(text) => println!("{}", text.trim()),
  156.                         MessageType::File { name, .. } => println!("Received file: {}", name),
  157.                     }
  158.  
  159.                     let mut clients = clients.lock().unwrap();
  160.                     for client in clients.iter_mut() {
  161.                         if client.id != client_id {
  162.                             if let Some(client_shared_secret) = client.shared_secret {
  163.                                 let message_to_send = match &message_type {
  164.                                     MessageType::Text(text) => text.clone(),
  165.                                     MessageType::File { name, content } => {
  166.                                         format!(
  167.                                             "FILE:{}:{}",
  168.                                             name,
  169.                                             String::from_utf8_lossy(content)
  170.                                         )
  171.                                     }
  172.                                 };
  173.                                 let encrypted_for_client =
  174.                                     encrypt_message(&message_to_send, &client_shared_secret);
  175.                                 client.stream.write_all(&encrypted_for_client).unwrap();
  176.                                 client.stream.flush().unwrap();
  177.                             }
  178.                         }
  179.                     }
  180.                 }
  181.             }
  182.             Err(_) => break,
  183.         }
  184.     }
  185.  
  186.     let mut clients = clients.lock().unwrap();
  187.     clients.retain(|client| client.id != client_id);
  188.     println!("Client@{}#{} disconnected", client_addr.ip(), client_id);
  189. }
  190.  
  191. fn run_server(addr: SocketAddr) -> io::Result<()> {
  192.     let listener = TcpListener::bind(addr)?;
  193.     println!("Server listening on {}", addr);
  194.     println!("Your IP address is: {}", addr.ip());
  195.  
  196.     let clients: Arc<Mutex<Vec<ClientInfo>>> = Arc::new(Mutex::new(Vec::new()));
  197.     let clients_clone = Arc::clone(&clients);
  198.  
  199.     // Spawn a thread to handle server input
  200.     let server_addr = addr;
  201.     thread::spawn(move || loop {
  202.         let mut input = String::new();
  203.         io::stdin().read_line(&mut input).unwrap();
  204.         let message = format!("Server@{}: {}\n", server_addr.ip(), input.trim());
  205.  
  206.         let mut clients = clients_clone.lock().unwrap();
  207.         for client in clients.iter_mut() {
  208.             if let Some(shared_secret) = client.shared_secret {
  209.                 let encrypted_message = encrypt_message(&message, &shared_secret);
  210.                 client.stream.write_all(&encrypted_message).unwrap();
  211.                 client.stream.flush().unwrap();
  212.             }
  213.         }
  214.     });
  215.  
  216.     for stream in listener.incoming() {
  217.         match stream {
  218.             Ok(mut stream) => {
  219.                 let client_id = generate_id();
  220.                 println!(
  221.                     "New client connected: {:?}#{}",
  222.                     stream.peer_addr()?,
  223.                     client_id
  224.                 );
  225.  
  226.                 let mut public_key_bytes = [0u8; 32];
  227.                 stream.read_exact(&mut public_key_bytes)?;
  228.                 let client_public_key = PublicKey::from(public_key_bytes);
  229.  
  230.                 let server_secret = EphemeralSecret::random_from_rng(OsRng);
  231.                 let server_public = PublicKey::from(&server_secret);
  232.  
  233.                 stream.write_all(server_public.as_bytes())?;
  234.  
  235.                 let shared_secret = server_secret.diffie_hellman(&client_public_key);
  236.  
  237.                 if CONFIG.debug {
  238.                     save_key(server_public.as_bytes(), "server-public", &client_id);
  239.                     save_key(shared_secret.as_bytes(), "shared-secret", &client_id);
  240.                 }
  241.  
  242.                 let client_info = ClientInfo {
  243.                     stream: stream.try_clone()?,
  244.                     id: client_id.clone(),
  245.                     public_key: client_public_key,
  246.                     shared_secret: Some(*shared_secret.as_bytes()),
  247.                 };
  248.                 let clients = Arc::clone(&clients);
  249.                 clients.lock().unwrap().push(client_info);
  250.  
  251.                 let server_addr = addr;
  252.                 let client_info = ClientInfo {
  253.                     stream,
  254.                     id: client_id,
  255.                     public_key: client_public_key,
  256.                     shared_secret: Some(*shared_secret.as_bytes()),
  257.                 };
  258.                 thread::spawn(move || {
  259.                     handle_connection(client_info, clients, server_addr);
  260.                 });
  261.             }
  262.             Err(e) => {
  263.                 eprintln!("Error accepting client: {}", e);
  264.             }
  265.         }
  266.     }
  267.     Ok(())
  268. }
  269.  
  270. fn run_client(server_addr: SocketAddr) -> io::Result<()> {
  271.     let mut stream = TcpStream::connect(server_addr)?;
  272.     println!("Connected to server at {}", server_addr);
  273.  
  274.     let client_addr = stream.local_addr()?;
  275.     let client_id = generate_id();
  276.     println!("Your address is: {}#{}", client_addr.ip(), client_id);
  277.  
  278.     let client_secret = EphemeralSecret::random_from_rng(OsRng);
  279.     let client_public = PublicKey::from(&client_secret);
  280.  
  281.     stream.write_all(client_public.as_bytes())?;
  282.  
  283.     let mut server_public_key_bytes = [0u8; 32];
  284.     stream.read_exact(&mut server_public_key_bytes)?;
  285.     let server_public_key = PublicKey::from(server_public_key_bytes);
  286.  
  287.     let shared_secret = client_secret.diffie_hellman(&server_public_key);
  288.  
  289.     if CONFIG.debug {
  290.         save_key(client_public.as_bytes(), "client-public", &client_id);
  291.         save_key(server_public_key.as_bytes(), "server-public", &client_id);
  292.         save_key(shared_secret.as_bytes(), "shared-secret", &client_id);
  293.     }
  294.  
  295.     let mut receive_stream = stream.try_clone()?;
  296.  
  297.     let ss_bytes = *shared_secret.as_bytes();
  298.  
  299.     thread::spawn(move || {
  300.         let mut buffer = [0; 1024 * 1024];
  301.         loop {
  302.             match receive_stream.read(&mut buffer) {
  303.                 Ok(0) => break,
  304.                 Ok(bytes_read) => {
  305.                     let encrypted_message = &buffer[..bytes_read];
  306.                     let decrypted_message = decrypt_message(encrypted_message, &ss_bytes);
  307.                     if decrypted_message.starts_with("FILE:") {
  308.                         let parts: Vec<&str> = decrypted_message[5..].splitn(2, ':').collect();
  309.                         if parts.len() == 2 {
  310.                             let filename = parts[0];
  311.                             let content = parts[1].as_bytes();
  312.                             fs::write(filename, content).unwrap();
  313.                             if CONFIG.debug {
  314.                                 save_file(content, &format!("debug_received_{}", filename));
  315.                             }
  316.                             println!("Received file: {}", filename);
  317.                         }
  318.                     } else {
  319.                         print!("{}", decrypted_message);
  320.                     }
  321.                     io::stdout().flush().unwrap();
  322.                 }
  323.                 Err(_) => break,
  324.             }
  325.         }
  326.     });
  327.  
  328.     loop {
  329.         let mut input = String::new();
  330.         io::stdin().read_line(&mut input)?;
  331.         let input = input.trim();
  332.  
  333.         if input.starts_with("/file ") {
  334.             let file_path = input[6..].trim();
  335.             if let Ok(content) = fs::read(file_path) {
  336.                 let file_name = Path::new(file_path).file_name().unwrap().to_str().unwrap();
  337.                 let message = format!("FILE:{}:{}", file_name, String::from_utf8_lossy(&content));
  338.                 if CONFIG.debug {
  339.                     save_file(&content, &format!("debug_sent_{}", file_name));
  340.                 }
  341.                 let encrypted_message = encrypt_message(&message, &ss_bytes);
  342.                 stream.write_all(&encrypted_message)?;
  343.                 stream.flush()?;
  344.                 println!("File sent: {}", file_name);
  345.             } else {
  346.                 println!("Failed to read file: {}", file_path);
  347.             }
  348.         } else {
  349.             let message = format!("Client@{}#{}: {}\n", client_addr.ip(), client_id, input);
  350.             let encrypted_message = encrypt_message(&message, &ss_bytes);
  351.             stream.write_all(&encrypted_message)?;
  352.             stream.flush()?;
  353.         }
  354.     }
  355. }
  356.  
  357. fn get_input(prompt: &str, default: &str) -> String {
  358.     println!("{} (default: {})", prompt, default);
  359.     let mut input = String::new();
  360.     io::stdin().read_line(&mut input).unwrap();
  361.     let input = input.trim();
  362.     if input.is_empty() {
  363.         default.to_string()
  364.     } else {
  365.         input.to_string()
  366.     }
  367. }
  368.  
  369. fn main() -> io::Result<()> {
  370.     println!("Choose mode: (1) Server, (2) Client");
  371.     let mode = match get_input("Enter 1 or 2", "1").as_str() {
  372.         "1" => Mode::Server,
  373.         "2" => Mode::Client,
  374.         _ => {
  375.             println!("Invalid choice. Defaulting to Server mode.");
  376.             Mode::Server
  377.         }
  378.     };
  379.  
  380.     match mode {
  381.         Mode::Server => {
  382.             let ip = get_input("Enter IP to bind to", "0.0.0.0");
  383.             let port = get_input("Enter port to listen on", "8080");
  384.             let addr: SocketAddr = format!("{}:{}", ip, port).parse().unwrap();
  385.             run_server(addr)
  386.         }
  387.         Mode::Client => {
  388.             let ip = get_input("Enter server IP", "127.0.0.1");
  389.             let port = get_input("Enter server port", "8080");
  390.             let addr: SocketAddr = format!("{}:{}", ip, port).parse().unwrap();
  391.             run_client(addr)
  392.         }
  393.     }
  394. }
  395.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement