Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import socket
- import time
- from queue import PriorityQueue
- import struct
- from udp import UDPBasedProtocol
- import constants as const
# Wire-format and transport tuning constants used by Packet and MyTCPProtocol.

# Size of the per-datagram header: two unsigned 64-bit big-endian integers in
# the "!QQ" layout Packet.dump()/Packet.parse() use -- sequence number first,
# then acknowledgement number. Derived from the format so the two cannot drift.
headers_len = struct.calcsize("!QQ")  # 16 bytes: 8 (seq) + 8 (ack)
# Seconds an unacknowledged segment may wait before it counts as expired and
# is retransmitted; also the per-wait receive timeout used by send() when it
# cannot emit new data.
ack_timeout = 0.01
# Maximum payload bytes per segment; the header is sent on top of this.
mss = 1500
# In-flight threshold: send() emits new data only while the number of
# unacknowledged bytes is at or below this value.
window_size = mss * 12
# Number of consecutive fruitless ack_timeout waits after which send() gives up.
ack_crit_lag = 20
class UDPBasedProtocol:
    """Thin wrapper around an IPv4 UDP socket bound to ``local_addr``.

    Outgoing datagrams always go to ``remote_addr``. Incoming datagrams are
    accepted from any sender.

    NOTE(review): this definition shadows the ``UDPBasedProtocol`` imported
    from ``udp`` at the top of the file -- confirm which one is intended.
    """

    def __init__(self, *, local_addr, remote_addr):
        """Create the socket and bind it to ``local_addr``.

        Raises whatever ``socket.bind`` raises (e.g. OSError for an address
        in use, TypeError for a malformed address). In that case the socket
        is closed first, so no file descriptor leaks.
        """
        self.udp_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
        self.remote_addr = remote_addr
        try:
            self.udp_socket.bind(local_addr)
        except BaseException:
            # The half-built object is discarded by the caller; release the fd now.
            self.udp_socket.close()
            raise

    def sendto(self, data: bytes):
        """Send ``data`` as one datagram to ``remote_addr``; return the byte count sent."""
        return self.udp_socket.sendto(data, self.remote_addr)

    def recvfrom(self, n):
        """Receive one datagram of at most ``n`` bytes and return only its payload.

        The sender's address is discarded, so datagrams from any peer are accepted.
        Honors whatever timeout is currently set on ``udp_socket``.
        """
        payload, _sender = self.udp_socket.recvfrom(n)
        return payload

    def close(self):
        """Close the underlying socket."""
        self.udp_socket.close()
class Packet:
    """One protocol datagram: a fixed header followed by a payload.

    The header is the sequence number and the acknowledgement number, each
    packed as an unsigned 64-bit big-endian integer (struct format "!QQ").
    """

    def __init__(self, seq_number: int, ack_number: int, data: bytes):
        # As used by MyTCPProtocol: seq_number is the offset of data[0] in the
        # sender's outgoing byte stream; ack_number is the count of the peer's
        # bytes the sender has received in order.
        self.seq_number = seq_number
        self.ack_number = ack_number
        self.data = data
        # While True, `expired` is always False.
        self.ack_flag = False
        # time.monotonic() timestamp of the last (re)transmission. A monotonic
        # clock is used so wall-clock adjustments cannot fire or suppress retransmits.
        self._sending_time = time.monotonic()

    def dump(self) -> bytes:
        """Serialize to wire format: header ("!QQ" seq, ack) + payload."""
        return struct.pack("!QQ", self.seq_number, self.ack_number) + self.data

    @staticmethod
    def parse(raw_bytes: bytes) -> 'Packet':
        """Decode a datagram produced by dump().

        Raises struct.error if ``raw_bytes`` is shorter than ``headers_len``.
        """
        seq, ack = struct.unpack("!QQ", raw_bytes[:headers_len])
        return Packet(seq, ack, raw_bytes[headers_len:])

    def update_sending_time(self, sending_time=None):
        """Record a (re)transmission time.

        ``sending_time`` must be a time.monotonic() timestamp; defaults to now.
        """
        self._sending_time = sending_time if sending_time is not None else time.monotonic()

    @property
    def expired(self):
        """True when not ``ack_flag`` and more than ``ack_timeout`` seconds have
        passed since the last recorded sending time."""
        return not self.ack_flag and (time.monotonic() - self._sending_time > ack_timeout)

    def __len__(self):
        """Payload length in bytes (header excluded)."""
        return len(self.data)

    def __lt__(self, other: 'Packet'):
        """Order by sequence number (used as a tie-breaker inside priority queues)."""
        return self.seq_number < other.seq_number

    def __repr__(self):
        return (f"{type(self).__name__}(seq_number={self.seq_number}, "
                f"ack_number={self.ack_number}, data=<{len(self.data)} bytes>)")
class MyTCPProtocol(UDPBasedProtocol):
    """Reliable, in-order byte stream over UDP (a simplified TCP).

    Every datagram is a Packet carrying the following, then up to ``mss`` payload bytes:
    - a sequence number: the byte offset of its payload in this side's outgoing stream;
    - a cumulative acknowledgement number: bytes of the peer's stream delivered in order.

    Unacknowledged segments are retransmitted once they are older than
    ``ack_timeout``. Out-of-order segments are held until the gap before them
    is filled. Every received data segment is answered with an empty
    (pure-ACK) segment.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._sent_bytes_cnt = 0  # end of our outgoing stream (next new seq number)
        self._confirmed_bytes_cnt = 0  # highest cumulative ack received from the peer
        self._received_bytes_cnt = 0  # bytes of the peer's stream delivered in order
        # Both windows hold (seq_number, Packet) pairs, smallest seq_number first.
        self._send_window = PriorityQueue()  # sent, awaiting acknowledgement
        self._recv_window = PriorityQueue()  # received, waiting for earlier bytes
        self._buffer = bytes()  # in-order payload not yet returned by recv()

    def send(self, data: bytes) -> int:
        """Send ``data`` reliably and return how many of its bytes were transmitted.

        Returns once every transmitted byte is acknowledged. It also returns
        after ``ack_crit_lag`` consecutive ``ack_timeout`` waits (taken while no
        new data can be sent) that receive nothing; the peer is then presumed
        gone. In that case the result may be smaller than ``len(data)`` and
        some bytes may remain unacknowledged.
        """
        total = len(data)
        offset = 0  # bytes of `data` handed to the network so far
        idle_waits = 0
        while ((offset < total or self._confirmed_bytes_cnt < self._sent_bytes_cnt)
               and idle_waits < ack_crit_lag):
            in_flight = self._sent_bytes_cnt - self._confirmed_bytes_cnt
            if offset < total and in_flight <= window_size:
                # Slice by offset so each step copies at most mss bytes;
                # re-slicing the remainder would be quadratic in len(data).
                segment = Packet(self._sent_bytes_cnt,
                                 self._received_bytes_cnt,
                                 data[offset:offset + mss])
                offset += self._send_segment(segment)
                # Non-blocking poll: absorb any ACK that has already arrived.
                self._receive_segment(0.)
            elif self._receive_segment(ack_timeout):
                idle_waits = 0
            else:
                idle_waits += 1
            # Checked every iteration, so a lost segment is retransmitted even
            # while the window still has room for new data.
            self._resend_earliest_segment_if_need()
        return offset

    def recv(self, n: int) -> bytes:
        """Return exactly ``n`` bytes of the peer's stream.

        Blocks without a timeout until enough bytes have arrived. Returns b""
        when ``n <= 0``.
        """
        if n <= 0:
            return b""
        chunks = bytearray()  # amortized O(1) appends, unlike repeated bytes +=
        while True:
            # Never take more than is still missing, so the result is exactly n bytes.
            take = min(n - len(chunks), len(self._buffer))
            chunks += self._buffer[:take]
            self._buffer = self._buffer[take:]
            if len(chunks) == n:
                return bytes(chunks)
            self._receive_segment()

    def _receive_segment(self, timeout: 'float | None' = None) -> bool:
        """Wait for one datagram and process its payload and acknowledgement.

        ``timeout`` is in seconds: None waits forever and 0 only polls.
        Returns True if a well-formed datagram was processed. Returns False on
        timeout or socket error, or when the datagram is too short to contain
        a header.
        """
        self.udp_socket.settimeout(timeout)
        try:
            raw = self.recvfrom(mss + headers_len)
        except OSError:
            # Covers socket.timeout and BlockingIOError (timeout=0, nothing queued).
            return False
        try:
            segment = Packet.parse(raw)
        except struct.error:
            # Too short to hold a header: not one of our datagrams; ignore it.
            return False
        if len(segment):
            self._recv_window.put((segment.seq_number, segment), block=False)
            self._shift_recv_window()
        if segment.ack_number > self._confirmed_bytes_cnt:
            self._confirmed_bytes_cnt = segment.ack_number
            self._shift_send_window()
        return True

    def _send_segment(self, segment: Packet) -> int:
        """Transmit ``segment`` and, if it carries data, track it for retransmission.

        Returns the length of the payload data sent.
        Raises ValueError (before anything is sent) if ``segment.seq_number``
        lies beyond the current end of the outgoing stream.
        """
        if segment.seq_number > self._sent_bytes_cnt:
            raise ValueError(f'Seq number {segment.seq_number} is bigger than {self._sent_bytes_cnt}')
        self.udp_socket.settimeout(None)
        just_sent = self.sendto(segment.dump()) - headers_len
        if segment.seq_number == self._sent_bytes_cnt:
            # New data (or a pure ACK with no payload): extend the stream end.
            self._sent_bytes_cnt += just_sent
        if len(segment):
            segment.data = segment.data[:just_sent]
            segment.update_sending_time()
            self._send_window.put((segment.seq_number, segment), block=False)
        return just_sent

    def _shift_recv_window(self):
        """Deliver contiguous segments at the head of the receive window, then ACK."""
        if self._recv_window.empty():
            return
        while not self._recv_window.empty():
            _, segment = self._recv_window.get(block=False)
            if segment.seq_number > self._received_bytes_cnt:
                # A gap precedes this segment: keep it queued and stop.
                self._recv_window.put((segment.seq_number, segment), block=False)
                break
            segment_end = segment.seq_number + len(segment)
            if segment_end > self._received_bytes_cnt:
                # Append only the bytes not delivered yet; that is all of them
                # unless the segment overlaps already-delivered data.
                self._buffer += segment.data[self._received_bytes_cnt - segment.seq_number:]
                self._received_bytes_cnt = segment_end
                segment.ack_flag = True
            # Otherwise it duplicates delivered data and is simply dropped.
        # Acknowledge even when only duplicate or out-of-order data arrived,
        # so a lost ACK is repaired when the peer retransmits.
        self._send_segment(Packet(self._sent_bytes_cnt, self._received_bytes_cnt, bytes()))

    def _shift_send_window(self):
        """Drop segments the peer has fully acknowledged from the send window."""
        while not self._send_window.empty():
            _, segment = self._send_window.get(block=False)
            if segment.seq_number + len(segment) > self._confirmed_bytes_cnt:
                # Not fully acknowledged; everything after it has a higher seq.
                self._send_window.put((segment.seq_number, segment), block=False)
                break

    def _resend_earliest_segment_if_need(self):
        """Retransmit the lowest-seq unacknowledged segment if it has expired."""
        if self._send_window.empty():
            return
        _, segment = self._send_window.get(block=False)
        if segment.expired:
            # Piggyback the freshest cumulative ACK rather than the stale one.
            segment.ack_number = self._received_bytes_cnt
            self._send_segment(segment)  # re-queues it with a new sending time
        else:
            self._send_window.put((segment.seq_number, segment), block=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement