How to verify a TCP checksum

theStig · Jan 19, 2013 · Viewed 17.8k times · Source

For some odd reason, I'm unable to properly verify the TCP checksum. I have code to check the IP and UDP checksums, and it works perfectly fine, but for TCP something in my logic is amiss.

My struct definitions for these headers are fine, as I can read the data perfectly well (verified with Wireshark). The only problem I'm having is with the TCP checksum: I'm unable to verify whether it is actually correct. Any thoughts as to where I'm going wrong?

Very much appreciated.

checksum function

/*
 * in_cksum - Internet checksum (RFC 1071) over a buffer.
 *
 * Returns the 16-bit one's complement of the one's complement sum of the
 * 16-bit words in the first `len` bytes at `addr`. Words are read in host
 * memory order, so the result has the same byte order as the data and can
 * be stored into a packet unchanged. Summing data that already contains a
 * correct checksum field yields 0.
 *
 * addr: start of the data. It is read through an unsigned short pointer,
 *       so it presumably must be 2-byte aligned (confirm at call sites).
 * len:  number of bytes. An odd trailing byte is padded with a zero byte.
 *       A len <= 0 sums nothing and returns 0xffff.
 */
unsigned short in_cksum(unsigned short *addr, int len)
{
    /* Unsigned 64-bit accumulator. Unsigned arithmetic cannot hit
     * signed-overflow UB, and 64 bits hold every carry for any int-sized
     * length, so no intermediate folding is needed. */
    unsigned long long sum = 0;
    const unsigned short *w = addr;
    int nleft = len;

    while (nleft > 1) {
        sum += *w++;
        nleft -= 2;
    }

    if (nleft == 1) {
        /* Pad the odd byte with zero. The byte takes the first memory
         * position of the final word, which is correct for either host
         * byte order. */
        unsigned short last = 0;
        *(unsigned char *)&last = *(const unsigned char *)w;
        sum += last;
    }

    /* End-around carry: fold the high bits back into the low 16 bits until
     * nothing is left above them. */
    while (sum >> 16) {
        sum = (sum & 0xffff) + (sum >> 16);
    }

    return (unsigned short)~sum;
}

readTCP function (old version; see the edited version below)

/*
 * readTCP (first attempt, kept for reference) - tries to verify the TCP
 * checksum of one captured frame laid out as Ethernet | IPv4 | TCP.
 * Packets are read using the pcap libraries.
 *
 * NOTE(review): known problems in this version:
 *  - The pseudo-header length and the summed block cover only the fixed
 *    20-byte struct TCP_Header. The TCP checksum must also cover TCP
 *    options and payload, and the pseudo-header length must be
 *    TCP header + data length.
 *  - The TCP header offset assumes an IP header of exactly
 *    sizeof(struct IP_Header) bytes, i.e. no IP options.
 *  - `tcp` is cast from the const `packets` buffer and then written to
 *    (tcp->tcp_checksum = 0), which modifies the capture buffer.
 *  - Several member names do not match the struct definitions shown with
 *    this code (source_ip, destination_ip, zero, length, tcp_checksum).
 *    TCPPseudoHeader and TCPHeader are not declared in this snippet, so it
 *    does not compile against those definitions as written.
 *  - The computed checksum and the saved tcpCheckSum are never used.
 */
void readTCP(const u_char *packets) {
   struct TCP_Header *tcp = (struct TCP_Header*) (packets + sizeof(struct Ethernet_Header) + sizeof(struct IP_Header));
   struct IP_Header *ip = (struct IP_Header*) (packets + sizeof(struct Ethernet_Header));
   struct TCP_Pseudo tcpPseudo;
   char tcpcsumblock[sizeof(struct TCP_Pseudo) + sizeof(struct TCP_Header)];

   /* tcp pseudo header */
   memset(&tcpPseudo, 0, sizeof(struct TCP_Pseudo));
   tcpPseudo.source_ip = ip->source_ip.s_addr;
   tcpPseudo.destination_ip = ip->destination_ip.s_addr;
   tcpPseudo.zero = 0;
   tcpPseudo.protocol = 6;
   /* NOTE(review): must be the TCP header + options + payload length, not the struct size */
   tcpPseudo.length = htons(sizeof(struct TCP_Header));

   /* grab tcp checksum and reset it (this writes into the caller's const capture buffer) */
   int tcpCheckSum = htons(tcp->tcp_checksum);
   tcp->tcp_checksum = 0;

   /* place the pseudo header in front of the tcp header (the payload is never copied) */
   memcpy(tcpcsumblock, &tcpPseudo, sizeof(TCPPseudoHeader));   
   memcpy(tcpcsumblock+sizeof(TCPPseudoHeader),tcp, sizeof(TCPHeader));

   /* here is the issue: the checksum calculated here isn't the correct checksum (checked by examining the packets in Wireshark) */
   u_short checksum = in_cksum((unsigned short *)tcpcsumblock, sizeof(tcpcsumblock));
}

==EDIT==

New readTCP function

/*
 * readTCP (edited attempt) - tries to verify the TCP checksum of one
 * captured Ethernet | IPv4 | TCP frame by copying the pseudo header and the
 * TCP segment into one buffer and checksumming it. Packets are read using
 * the pcap libraries.
 *
 * NOTE(review): known problems in this version:
 *  - ip->ip_len is in network byte order and is used here without ntohs().
 *  - tcpPseudo.len holds a value already converted with htons(), but it is
 *    then used as a host-order byte count for the VLA size and for memcpy.
 *    On little-endian hosts that count is wrong and can be far too large.
 *  - The VLA is sized from packet data, so a bogus length risks a stack
 *    overflow.
 *  - The TCP offset assumes an IP header of exactly sizeof(struct IP_Header)
 *    bytes (no IP options), while the length uses ip_hdr_len.
 *  - Member names source_ip, destination_ip and zero do not match the
 *    struct definitions shown (ip_src/ip_dst, src_ip/dest_ip, zeroes).
 *  - `ps_tcp` is not declared (presumably tcpcsumblock was intended).
 *    `tcp` and `cs` are unused.
 */
void readTCP(const u_char *packets) {
   struct TCP_Header *tcp = (struct TCP_Header*) (packets + sizeof(struct Ethernet_Header) + sizeof(struct IP_Header));
   struct IP_Header *ip = (struct IP_Header*) (packets + sizeof(struct Ethernet_Header));
   struct TCP_Pseudo tcpPseudo;

   /* tcp pseudo header */
   memset(&tcpPseudo, 0, sizeof(struct TCP_Pseudo));
   tcpPseudo.source_ip = ip->source_ip;
   tcpPseudo.destination_ip = ip->destination_ip;
   tcpPseudo.zero = 0;
   tcpPseudo.protocol = 6;
   /* NOTE(review): needs ntohs(ip->ip_len); the result is network-order from here on */
   tcpPseudo.len = htons(ip->ip_len - (ip->ip_hdr_len * 4));

   /* NOTE(review): tcpPseudo.len is network-order here; a host-order length is needed */
   int len = sizeof(struct TCP_Pseudo) + tcpPseudo.len;
   u_char tcpcsumblock[len];

   memcpy(tcpcsumblock, &tcpPseudo, sizeof(struct TCP_Pseudo));
   memcpy(tcpcsumblock + sizeof(struct TCP_Pseudo), (packets + sizeof(struct Ethernet_Header) + sizeof(struct IP_Header)), tcpPseudo.len);

   /* here is the issue: the checksum calculated here isn't the correct checksum (checked by examining the packets in Wireshark) */
   u_short checksum = in_cksum((unsigned short *)ps_tcp, len);
   char *cs = checksum ? "Invalid Checksum!" : "Valid!";
}

IP header

/*
 * IPv4 header (RFC 791), packed so it can be overlaid on raw packet bytes.
 * Multi-byte fields are in network byte order; convert them with
 * ntohs()/ntohl() before use. Options, if present (ip_hdr_len > 5), follow
 * this fixed 20-byte part.
 */
struct IP_Header {
   /*
    * Version (high nibble) and IHL (low nibble) share the first byte.
    * GCC and Clang allocate bit-fields from the low-order bits on
    * little-endian targets, so the declaration order must flip with the host
    * byte order. The predefined pair to compare is __BYTE_ORDER__ against
    * __ORDER_LITTLE_ENDIAN__. __LITTLE_ENDIAN__, where defined at all, is a
    * boolean flag, not a byte-order value.
    */
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
   uint8_t ip_hdr_len:4;   /* header length in 32-bit words (IHL) */
   uint8_t ip_version:4;   /* ip version */
#else
   uint8_t ip_version:4;   /* ip version */
   uint8_t ip_hdr_len:4;   /* header length in 32-bit words (IHL) */
#endif

   uint8_t ip_tos;      /* type of service */
   uint16_t ip_len;     /* total length: header + payload, in bytes */
   uint16_t ip_id;      /* identification */
   uint16_t ip_off;     /* flags and fragment offset field */
#define IP_DF 0x4000            /* dont fragment flag */
#define IP_MF 0x2000            /* more fragments flag */
#define IP_OFFMASK 0x1fff       /* mask for fragmenting bits */
   uint8_t  ip_ttl;     /* time to live */
   uint8_t  ip_p;       /* protocol (6 = TCP, 17 = UDP) */
   uint16_t ip_sum;     /* header checksum */
   struct in_addr ip_src, ip_dst;   /* source and dest address */
} __attribute__ ((packed));

TCP header

/*
 * TCP header (RFC 793): the fixed 20-byte part, packed so it can be
 * overlaid on raw packet bytes. Multi-byte fields are in network byte
 * order. Options, if any, follow this part; the real header length in
 * bytes is TH_OFF(th) * 4.
 */
struct TCP_Header {
  uint16_t tcp_source_port; /* source port */
  uint16_t tcp_dest_port;   /* destination port */
  uint32_t tcp_seq;         /* sequence number */
  uint32_t tcp_ack;         /* acknowledgement number */
  uint8_t tcp_offest;       /* data offset (high nibble), reserved bits, NS (lowest bit); name kept for compatibility */
/* Data offset (header length in 32-bit words) of the TCP header pointed to by `th`. */
#define TH_OFF(th)      (((th)->tcp_offest & 0xf0) >> 4)
  uint8_t tcp_flags;        /* flags */
#define TH_FIN      0x01
#define TH_SYN      0x02
#define TH_RST      0x04
#define TH_PUSH     0x08
#define TH_ACK      0x10
#define TH_URG      0x20
#define TH_ECE      0x40
#define TH_CWR      0x80

/* TH_NS and TH_RS do not fit in the 8-bit tcp_flags. They are masks for the
 * 16-bit value ((tcp_offest << 8) | tcp_flags): the NS bit and the reserved
 * bits of the offset byte. */
#define TH_NS       0x100
#define TH_RS       0xE00

  uint16_t tcp_window;      /* window */
  uint16_t tcp_sum;         /* checksum */
  uint16_t tcp_urp;         /* urgent pointer */
} __attribute__ ((packed));

TCP pseudo header

/*
 * IPv4 pseudo header that is prepended, for checksumming only, to the TCP
 * segment when computing the TCP checksum (RFC 793, section 3.1). Packed,
 * so it is exactly 12 bytes. Multi-byte fields are in network byte order.
 */
struct TCP_Pseudo {
   struct in_addr src_ip;  /* source ip, copied from the IP header */
   struct in_addr dest_ip; /* destination ip, copied from the IP header */
   uint8_t zeroes;         /* = 0 */
   uint8_t protocol;       /* = 6 (TCP) */
   uint16_t len;           /* TCP segment length: header + options + data, not just sizeof(struct TCP_Header) */
} __attribute__ ((packed));

Answer

Mecki · Jan 19, 2013

The problem is this line:

tcpPseudo.length = htons(sizeof(struct TCP_Header)); /* fixed 20-byte TCP header only: omits options and payload */

According to RFC 793:

The TCP Length is the TCP header length plus the data length in octets (this is not an explicitly transmitted quantity, but is computed), and it does not count the 12 octets of the pseudo header.

You only set the TCP header length, but it should be TCP header length + data length.

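The data length is the Total Length reported by the IP header, minus the IP header length (the field is named IHL in the IP header and must be multiplied by 4 to get the length in bytes), minus the TCP header length. The TCP header length is the Data Offset field multiplied by 4; it is not always 20 bytes, because TCP options make the header variable in size.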

[Image not preserved — presumably a diagram of the IPv4 header showing the IHL and Total Length fields.]

However, since you want the TCP header length plus the data length anyway, you only have to subtract the IP header length from the total packet length. What is left over is the sum of the TCP header length and the data length.

tcpPseudo.len = htons(ntohs(ip->ip_len) - (ip->ip_hdr_len * 4)); /* TCP header + options + data, in bytes */

Also according to the RFC:

The checksum field is the 16 bit one's complement of the one's complement sum of all 16 bit words in the header and text.

"And text" means all the data following the TCP header is also checksummed. Otherwise TCP could not guarantee that the data was transmitted correctly. Remember, TCP is a reliable protocol that will retransmit corrupted data, but to do so it must also recognize when data got corrupted.

So the whole packet minus the IP header must be added to tcpcsumblock. Therefore tcpcsumblock must be big enough for whole packets to fit (in the case of TCP, 1500 bytes are usually enough, though in theory an IP packet may be as big as 64 KB and will be fragmented if needed). Then you must add the pseudo header, the TCP header, and everything up to the end of the packet.


I wrote a working piece of code for you. I verified that it works correctly by feeding in some real-life data. As a bonus, this implementation performs a couple of sanity checks on the IP header, including verifying its checksum: if that doesn't match, all header fields may contain garbage, as the header most likely got corrupted. It also doesn't need to allocate any dynamic memory or copy any data around, so it should be pretty fast. To support that last feature, I had to change the in_cksum function to accept a third parameter. If this parameter is zero, the function behaves exactly like the version in your code. If you pass the bitwise complement (~) of a previously returned checksum, it updates that checksum as if the new data had directly followed the data already checksummed (pretty nifty, huh?). For the 16-bit words to line up, the earlier data must have an even length.

/*
 * in_cksum - Internet checksum (RFC 1071), optionally continuing an earlier
 * computation.
 *
 * addr: data to checksum. It is read byte by byte, so any alignment is
 *       accepted.
 * len:  number of bytes. An odd trailing byte is padded with a zero byte.
 * init: 0 for a fresh checksum. To continue as if this data directly
 *       followed data that was already checksummed, pass the bitwise
 *       complement of that earlier result, (uint16_t)~previous. The earlier
 *       data must have had an even length for the 16-bit words to line up.
 *
 * Returns the one's complement of the one's complement sum, in the same
 * byte order as the data. The result is 0 when the summed data already
 * contains a correct checksum field.
 */
uint16_t in_cksum (const void * addr, unsigned len, uint16_t init) {
  const unsigned char *p = addr;
  /* A 64-bit accumulator holds every carry for any unsigned len, so no
   * intermediate folding is needed. */
  uint64_t sum = init;
  /* Each 16-bit word is assembled through a union. The bytes combine in
   * host memory order, like a uint16_t load would, without strict-aliasing
   * or alignment problems on arbitrary byte buffers. */
  union {
    uint16_t word;
    unsigned char byte[2];
  } w;

  while (len >= 2) {
    w.byte[0] = p[0];
    w.byte[1] = p[1];
    sum += w.word;
    p += 2;
    len -= 2;
  }

  if (len > 0) {
    /* Odd trailing byte: pad it with an explicit zero byte (RFC 1071). */
    w.byte[0] = p[0];
    w.byte[1] = 0;
    sum += w.word;
  }

  /* End-around carry: fold the high bits into the low 16 until none remain. */
  while (sum >> 16) {
    sum = (sum & 0xffff) + (sum >> 16);
  }
  return (uint16_t)~sum;
}


void readTCP (const u_char *packets) {
  uint16_t csum;
  unsigned ipHdrLen;
  unsigned ipPacketLen;
  unsigned ipPayloadLen;
  struct TCP_Pseudo pseudo;
  const struct IP_Header * ip;
  const struct TCP_Header * tcp;

  // Verify IP header and calculate IP payload length
  ip = (const struct IP_Header *)(packets + sizeof(struct Ethernet_Header));
  ipHdrLen = ip->ip_hdr_len * 4;
  if (ipHdrLen < sizeof(struct IP_Header)) {
    // Packet is broken!
    // IP packets must not be smaller than the mandatory IP header.
    return;
  }
  if (in_cksum(ip, ipHdrLen, 0) != 0) {
    // Packet is broken!
    // Checksum of IP header does not verify, thus header is corrupt.
    return;
  }
  ipPacketLen = ntohs(ip->ip_len);
  if (ipPacketLen < ipHdrLen) {
    // Packet is broken!
    // The overall packet cannot be smaller than the header.
    return;
  }
  ipPayloadLen = ipPacketLen - ipHdrLen;

  // Verify that there really is a TCP header following the IP header
  if (ip->ip_p != 6) {
      // No TCP Packet!
      return;
  }
  if (ipPayloadLen < sizeof(struct TCP_Header)) {
    // Packet is broken!
    // A TCP header doesn't even fit into the data that follows the IP header.
    return;
  }

  // TCP header starts directly after IP header
  tcp = (const struct TCP_Header *)((const u_char *)ip + ipHdrLen);

  // Build the pseudo header and checksum it
  pseudo.src_ip = ip->ip_src;
  pseudo.dest_ip = ip->ip_dst;
  pseudo.zeroes = 0;
  pseudo.protocol = 6;
  pseudo.len = htons(ipPayloadLen);
  csum = in_cksum(&pseudo, (unsigned)sizeof(pseudo), 0);

  // Update the checksum by checksumming the TCP header
  // and data as if those had directly followed the pseudo header
  csum = in_cksum(tcp, ipPayloadLen, (uint16_t)~csum);

  char * cs = csum ? "Invalid Checksum!" : "Valid!";
  printf("%s\n", cs);
}