Advertisement
dream_4ild

Untitled

Nov 12th, 2024
53
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 6.24 KB | None | 0 0
  1. import socket
  2. import time
  3. from queue import PriorityQueue
  4. import struct
  5.  
  6. from udp import UDPBasedProtocol
  7. import constants as const
  8.  
  9.  
  10.  
  11. headers_len = 16 # 8(ack) + 8(seq)
  12. ack_timeout = 0.01
  13.  
  14. mss = 1500
  15. window_size = mss * 12
  16. ack_crit_lag = 20
  17.  
  18. class UDPBasedProtocol:
  19.     def __init__(self, *, local_addr, remote_addr):
  20.         self.udp_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
  21.         self.remote_addr = remote_addr
  22.         self.udp_socket.bind(local_addr)
  23.        
  24.  
  25.     def sendto(self, data: bytes):
  26.         return self.udp_socket.sendto(data, self.remote_addr)
  27.  
  28.     def recvfrom(self, n):
  29.         msg, addr = self.udp_socket.recvfrom(n)
  30.         return msg
  31.  
  32.     def close(self):
  33.         self.udp_socket.close()
  34.  
  35.  
  36. class Packet:
  37.     def __init__(self, seq_number: int, ack_number: int, data: bytes):
  38.         self.seq_number = seq_number
  39.         self.ack_number = ack_number
  40.         self.data = data
  41.         self.ack_flag = False
  42.         self._sending_time = time.time()
  43.  
  44.     def dump(self) -> bytes:
  45.         return struct.pack("!QQ", self.seq_number, self.ack_number) + self.data
  46.  
  47.     @staticmethod
  48.     def parse(raw_bytes: bytes):
  49.         seq, ack = struct.unpack("!QQ", raw_bytes[:headers_len])
  50.         return Packet(seq, ack, raw_bytes[headers_len:])
  51.  
  52.     def update_sending_time(self, sending_time=None):
  53.         self._sending_time = sending_time if sending_time is not None else time.time()
  54.  
  55.     @property
  56.     def expired(self):
  57.         return not self.ack_flag and (time.time() - self._sending_time > ack_timeout)
  58.  
  59.     def __len__(self):
  60.         return len(self.data)
  61.  
  62.     def __lt__(self, other: 'Packet'):
  63.         return self.seq_number < other.seq_number
  64.  
  65.  
  66. class MyTCPProtocol(UDPBasedProtocol):
  67.     def __init__(self, *args, **kwargs):
  68.         super().__init__(*args, **kwargs)
  69.  
  70.         self._sent_bytes_cnt = 0
  71.         self._confirmed_bytes_cnt = 0
  72.         self._received_bytes_cnt = 0
  73.  
  74.         self._send_window = PriorityQueue()
  75.         self._recv_window = PriorityQueue()
  76.  
  77.         self._buffer = bytes()
  78.  
  79.     def send(self, data: bytes) -> int:
  80.         res_len = 0
  81.         packet_lag = 0
  82.  
  83.         while (data or self._confirmed_bytes_cnt < self._sent_bytes_cnt) and (packet_lag < ack_crit_lag):
  84.             if self._sent_bytes_cnt - self._confirmed_bytes_cnt <= window_size and data:
  85.                 right_border = min(mss, len(data))
  86.                 sent_len = self._send_segment(Packet(self._sent_bytes_cnt,
  87.                                                             self._received_bytes_cnt,
  88.                                                             data[:right_border]))
  89.                 data = data[sent_len:]
  90.                 res_len += sent_len
  91.                 self._receive_segment(0.)
  92.             else:
  93.                 if self._receive_segment(ack_timeout):
  94.                     packet_lag = 0
  95.                 else:
  96.                     packet_lag += 1
  97.             self._resend_earliest_segment_if_need()
  98.  
  99.         return res_len
  100.  
  101.     def recv(self, n: int) -> bytes:
  102.         right_border = min(n, len(self._buffer))
  103.         data = self._buffer[:right_border]
  104.         self._buffer = self._buffer[right_border:]
  105.  
  106.         while len(data) < n:
  107.             self._receive_segment()
  108.             right_border = min(n, len(self._buffer))
  109.             data += self._buffer[:right_border]
  110.             self._buffer = self._buffer[right_border:]
  111.  
  112.         return data
  113.  
  114.     def _receive_segment(self, timeout: float = None) -> bool:
  115.         self.udp_socket.settimeout(timeout)
  116.         try:
  117.             segment = Packet.parse(self.recvfrom(mss + headers_len))
  118.         except socket.error:
  119.             return False
  120.  
  121.         if len(segment):
  122.             self._recv_window.put((segment.seq_number, segment), block=False)
  123.             self._shift_recv_window()
  124.  
  125.         if segment.ack_number > self._confirmed_bytes_cnt:
  126.             self._confirmed_bytes_cnt = segment.ack_number
  127.             self._shift_send_window()
  128.  
  129.         return True
  130.  
  131.     def _send_segment(self, segment: Packet) -> int:
  132.         """
  133.        @return: длина отправленных данных
  134.        """
  135.         self.udp_socket.settimeout(None)
  136.         just_sent = self.sendto(segment.dump()) - headers_len
  137.  
  138.         if segment.seq_number == self._sent_bytes_cnt:
  139.             self._sent_bytes_cnt += just_sent
  140.         elif segment.seq_number > self._sent_bytes_cnt:
  141.             raise ValueError(f'Seq number {segment.seq_number} is bigger than {self._sent_bytes_cnt}')
  142.  
  143.         if len(segment):
  144.             segment.data = segment.data[:just_sent]
  145.             segment.update_sending_time()
  146.             self._send_window.put((segment.seq_number, segment), block=False)
  147.  
  148.         return just_sent
  149.  
  150.     def _shift_recv_window(self):
  151.         earliest_segment = None
  152.         while not self._recv_window.empty():
  153.             _, earliest_segment = self._recv_window.get(block=False)
  154.             if earliest_segment.seq_number < self._received_bytes_cnt:
  155.                 pass
  156.             elif earliest_segment.seq_number == self._received_bytes_cnt:
  157.                 self._buffer += earliest_segment.data
  158.                 self._received_bytes_cnt += len(earliest_segment)
  159.                 earliest_segment.ack_flag = True
  160.             else:
  161.                 self._recv_window.put((earliest_segment.seq_number, earliest_segment), block=False)
  162.                 break
  163.  
  164.         if earliest_segment is not None:
  165.             self._send_segment(Packet(self._sent_bytes_cnt, self._received_bytes_cnt, bytes()))
  166.  
  167.     def _shift_send_window(self):
  168.         while not self._send_window.empty():
  169.             _, earliest_segment = self._send_window.get(block=False)
  170.             if earliest_segment.seq_number >= self._confirmed_bytes_cnt:
  171.                 self._send_window.put((earliest_segment.seq_number, earliest_segment), block=False)
  172.                 break
  173.  
  174.     def _resend_earliest_segment_if_need(self):
  175.         if self._send_window.empty():
  176.             return
  177.         _, earliest_segment = self._send_window.get(block=False)
  178.         if earliest_segment.expired:
  179.             self._send_segment(earliest_segment)
  180.         else:
  181.             self._send_window.put((earliest_segment.seq_number, earliest_segment), block=False)
  182.  
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement