// https://www2.cs.duke.edu/courses/fall16/compsci356/DNS/DNS-primer.pdf
// https://en.wikipedia.org/wiki/Domain_Name_System

// DNS Cache is a HashTable, similar to ARP Cache

// Number of buckets requested from HashTableNew for the DNS cache.
#define DNS_HASHTABLE_SIZE      2048    // 1024 might be fine, test it
// Hash entry type tag stored in CDNSHash->type and used as the HashFind mask.
#define HTT_DNS                         0x00100 // identical to HTT_DICT_WORD

// Flag bit OR'd into the header flags word of every outgoing query.
// NOTE(review): presumably the "recursion desired" bit of RFC 1035 -- confirm.
#define DNS_FLAG_RD             0x0100

// Opcode value; shifted left by 11 when the flags word is built in DNSQuestionSend.
#define DNS_OP_QUERY    0

// Question/record type used for queries and matched against answers (with rd_length == 4).
#define DNS_TYPE_A              1

// Question/record class used for queries and matched against answers.
#define DNS_CLASS_IN    1

// Assigned to the UDP socket's receive_timeout_ms, so the unit is milliseconds.
#define DNS_TIMEOUT             5000

// DNSQueryRun gives up once this many send/receive rounds have been attempted.
#define DNS_MAX_RETRIES 5

// One DNS cache entry: a CHash keyed by hostname that carries the resolved address info.
class CDNSHash:CHash
{       // store U8 *hostname as CHash->str U8 *
        CAddressInfo info;      // filled via AddressInfoCopy in DNSCachePut; info.address is freed there before an overwrite
        // Shrine has 'TODO: honor TTL' ...
        // Duke: 'TTL: the number of seconds the results can be cached'
        // perhaps have a separate task for removing cached results ?
};

// A domain name split into its dot-separated labels.
class CDNSDomainName
{
        U8 **labels;            // array of pointers to NUL-terminated labels; both producers in this file allocate 16 slots.
                                // All label text lives in one allocation and the free routines release it through labels[0].
        I64  num_labels;        // number of valid entries in labels
};

// One entry of a DNS question section; also used as a singly linked list node by DNSResponseParse.
class CDNSQuestion
{
        CDNSQuestion    *next;          // next question in the chain, NULL at the end
        CDNSDomainName   q_name;        // name being asked about
        U16                              q_type;        // DNSQuestionBuild stores DNS_TYPE_A here; DNSQuestionSerialize emits it high byte first
        U16                              q_class;       // DNSQuestionBuild stores DNS_CLASS_IN here; DNSQuestionSerialize emits it high byte first
};

// Header overlaid directly on the start of a DNS packet buffer.
// Fields hold the on-wire values: this file converts with EndianU16 on every write and read.
class CDNSHeader
{
        U16 id;                 // query identifier; a response is rejected when it does not match the id that was sent
        U16 flags;              // opcode and flag bits (built from DNS_OP_QUERY << 11 | DNS_FLAG_RD for queries)
        U16 q_count;    // number of entries in question section
        U16 a_count;    // number of resource records in answer section
        U16 ns_count;   // number of name server resource records in authority records section
        U16 ar_count;   // number of resource records in additional records section
};

// One parsed resource record; also a singly linked list node built by DNSResponseParse.
// DNSRRParse fills type..rd_length with a single 10-byte MemCopy from the packet, so those four
// fields must stay contiguous in this order and keep the packet's byte order (readers apply EndianU16).
class CDNSRR
{ // RR: Resource Record
        CDNSRR                  *next;
        CDNSDomainName   name;          // name of the node this record is for
        U16                              type;          // RR type, e.g. 44=SSHFP, 15=MX, 49=DHCID ...
        U16                              rr_class;      // class code
        U32                              ttl;           // count in seconds that RR stays valid (max = 2^31 - 1)
        U16                              rd_length;     // length of r_data member
        U8                              *r_data;        // additional RR-specific data; points into the packet buffer, not separately allocated
};

// Address of the DNS server that queries are sent to.
class CDNSGlobals
{
        U16                                     addr_family;    // 0 (AF_UNSPEC) after DNSCacheInit; AF_INET once DNSResolverIPV4Set has run
        CIPAddressStorage       dns_ip;                 // resolver address; accessed through a CIPV4Address pointer when addr_family is AF_INET

} dns_globals;

// Hostname -> CDNSHash table; NULL until DNSCacheInit creates it.
CHashTable *dns_cache = NULL;

U0 DNSCacheInit()
{ // Reset the global resolver address to "unset" and create the DNS cache hash table.
        dns_globals.addr_family = 0;
        MemSet(&dns_globals.dns_ip, 0, sizeof(CIPAddressStorage));

        dns_cache = HashTableNew(DNS_HASHTABLE_SIZE);
}

CDNSHash *DNSCacheFind(U8 *hostname)
{ // Look hostname up among the HTT_DNS entries of dns_cache.
  // Returns the matching entry, or NULL after logging a warning when there is none.
        CDNSHash *found = HashFind(hostname, dns_cache, HTT_DNS);

        if (!found)
                NetWarn("DNS CACHE FIND: Could not find a hostname in the DNS Cache.");

        return found;
}

CDNSHash *DNSCachePut(U8 *hostname, CAddressInfo *info)
{ // Insert or overwrite the cache entry for hostname with a copy of info.
  // Returns the cache entry that now holds the copy.
        CDNSHash *cached;

        NetLog("DNS CACHE PUT: Attempting Find DNS Entry in Cache: hostname: %s", hostname);
        cached = DNSCacheFind(hostname);

        if (cached)
        { // existing entry: drop the old address buffer, then copy the new info over it
                NetWarn("DNS CACHE PUT: Entry was already found in Cache. Overwriting.");
                Free(cached->info.address);
                AddressInfoCopy(&cached->info, info);

                return cached;
        }

        // no entry yet: build a fresh hash node keyed by a private copy of the hostname
        cached = CAlloc(sizeof(CDNSHash));
        cached->type = HTT_DNS;
        cached->str  = StrNew(hostname);

        AddressInfoCopy(&cached->info, info);
        HashAdd(cached, dns_cache);

        return cached;
}

I64 DNSQuestionSizeCalculate(CDNSQuestion *q)
{ // Number of bytes DNSQuestionSerialize writes for q:
  // one length byte plus the text of every label, one terminating zero byte, then 2 bytes type + 2 bytes class.
        I64 total = 1 + 4;
        I64 idx   = q->q_name.num_labels;

        while (idx-- > 0)
                total += StrLen(q->q_name.labels[idx]) + 1;

        return total;
}

U0 DNSQuestionSerialize(U8 *buffer, CDNSQuestion *q)
{ // Write q into buffer: each label as <length byte><text>, a zero byte, then type and class high byte first.
  // buffer must have room for DNSQuestionSizeCalculate(q) bytes.
        I64  idx;
        I64  text_len;
        U8  *text;

        for (idx = 0; idx < q->q_name.num_labels; idx++)
        {
                text     = q->q_name.labels[idx];
                text_len = StrLen(text);

                *buffer++ = text_len;
                MemCopy(buffer, text, text_len);
                buffer += text_len;
        }

        buffer[0] = 0; // end of name
        buffer[1] = q->q_type >> 8;
        buffer[2] = q->q_type & 0xFF;
        buffer[3] = q->q_class >> 8;
        buffer[4] = q->q_class & 0xFF;
}

I64 DNSQuestionSend(U16 id, U16 local_port, CDNSQuestion *q)
{ // Build one DNS query packet holding question q and send it over UDP to port 53 of the global resolver.
  //    id:         query identifier written into the header
  //    local_port: UDP source port
  // Returns 0 on success, a negative value on failure. Throws 'DNS' when the resolver family is AF_INET6.
        CIPV4Address *ipv4_addr = NULL;
        U8                       *dns_frame;
        U16                       flags;
        CDNSHeader       *header;
        I64                       de_index;

        switch (dns_globals.addr_family)
        {
                case AF_UNSPEC: // 0, global dns ip not set
                        NetErr("DNS SEND QUESTION: Failed, global dns addr family was AF_UNSPEC.");
                        return -1;

                case AF_INET6:
                        NetErr("DNS SEND QUESTION: Failed, IPV6 not supported yet in DNS.");
                        throw('DNS');

                case AF_INET:
                        ipv4_addr = &dns_globals.dns_ip;

                        if (!*ipv4_addr)
                        {
                                NetErr("DNS SEND QUESTION: Failed, ipv4_addr had no value set.");
                                return -1;
                        }
                        break;

                default: // any other family would leave ipv4_addr unset and be dereferenced below
                        NetErr("DNS SEND QUESTION: Failed, global dns addr family was not recognized.");
                        return -1;
        }

        // UDPPacketAllocate currently only accepts IPV4 ...
        de_index = UDPPacketAllocate(&dns_frame,
                                     IPV4AddressGet,
                                     local_port,
                                     *ipv4_addr,
                                     53,
                                     sizeof(CDNSHeader) + DNSQuestionSizeCalculate(q));
        if (de_index < 0)
        {
                NetErr("DNS SEND QUESTION: Failed, UDPPacketAllocate returned error.");
                return de_index;
        }

        flags = DNS_OP_QUERY << 11 | DNS_FLAG_RD;

        // the header sits at the start of the UDP payload, the serialized question right behind it
        header = dns_frame;

        header->id              = EndianU16(id);
        header->flags           = EndianU16(flags);
        header->q_count         = EndianU16(1);
        header->a_count         = 0;
        header->ns_count        = 0;
        header->ar_count        = 0;

        DNSQuestionSerialize(dns_frame + sizeof(CDNSHeader), q);

        UDPPacketFinish(de_index);
        return 0;

}


I64 DNSDomainNameParse(U8 *packet_data, I64 packet_length, U8 **data_inout, I64 *length_inout, CDNSDomainName *name_out)
{ // Decode one (possibly compressed) domain name from a DNS packet.
  //    packet_data / packet_length: the whole packet, needed to follow compression pointers
  //    data_inout / length_inout:   read cursor and bytes left; advanced past the name on success
  //    name_out:                    receives a 16-slot label array; all label text lives in one 256 byte
  //                                 buffer whose base is always labels[0], even when there are no labels
  // Returns 0 on success. Returns -1 on malformed input; then nothing stays allocated and
  // name_out->labels is NULL.
        U8  *data = *data_inout;
        U8  *name_base;
        U8  *name_buf;
        I64  length = *length_inout;
        I64  label_len;
        I64  offset;
        I64  name_used  = 0;    // bytes of the 256 byte name buffer consumed so far
        I64  jumps      = 0;    // compression pointers followed so far
        Bool jump_taken = FALSE;
        Bool terminated = FALSE;

        name_out->labels        = NULL;
        name_out->num_labels    = 0;

        if (length < 1)
        {
                NetErr("DNS PARSE DOMAIN NAME: Length less than one.");
                return -1;
        }

        name_out->labels = CAlloc(16 * sizeof(U8 *));

        name_base = CAlloc(256);
        name_buf  = name_base;
        name_out->labels[0] = name_base;

        while (length > 0)
        {
                label_len = *data++;
                length--;

                if (label_len == 0)
                { // zero length byte ends the name
                        terminated = TRUE;
                        break;
                }
                else if (label_len >= 192)
                { // top two bits set: 14-bit offset from the start of the packet, low byte follows
                        if ((length < 1) || (++jumps > 32)) // missing second byte, or a pointer loop
                                goto dnp_fail;

                        offset = ((label_len & 0x3F) << 8) | *data;

                        if (offset >= packet_length) // target must lie inside the packet
                                goto dnp_fail;

                        if (!jump_taken)
                        { // the caller's cursor resumes right after the first pointer
                                *data_inout     = data   + 1;
                                *length_inout   = length - 1;
                                jump_taken      = TRUE;
                                NetLog("DNS PARSE DOMAIN NAME: Jump taken");
                        }

                        data    = packet_data + offset;
                        length  = packet_length - offset;
                }
                else
                {
                        if ((label_len > 63)                    ||      // 64..191 are not plain label lengths
                            (length < label_len)                ||      // label runs past the data
                            (name_out->num_labels >= 16)        ||      // label pointer array is full
                            (name_used + label_len + 1 > 256))          // name text buffer is full
                                goto dnp_fail;

                        MemCopy(name_buf, data, label_len);
                        data    += label_len;
                        length  -= label_len;

                        name_buf[label_len] = 0;
                        name_out->labels[name_out->num_labels++] = name_buf;

                        name_buf  += label_len + 1;
                        name_used += label_len + 1;
                }
        }

        if (!terminated) // data ended before the zero byte that closes a name
                goto dnp_fail;

        if (!jump_taken)
        {
                *data_inout     = data;
                *length_inout   = length;
        }

        return 0;

dnp_fail:
        NetErr("DNS PARSE DOMAIN NAME: Malformed name.");
        Free(name_base);
        Free(name_out->labels);
        name_out->labels        = NULL;
        name_out->num_labels    = 0;

        return -1;
}


I64 DNSQuestionParse(U8 *packet_data, I64 packet_length, U8 **data_inout, I64 *length_inout, CDNSQuestion *q_out)
{ // Parse one question entry (name, then 2 bytes type and 2 bytes class) at the cursor and advance the cursor.
  // q_type / q_class are assembled with the first packet byte in the low byte of each field.
  // Returns 0 on success, a negative value when the name fails to parse or fewer than 4 bytes follow it.
        U8 *cursor;
        I64 remaining;
        I64 status = DNSDomainNameParse(packet_data, packet_length, data_inout, length_inout, &q_out->q_name);

        if (status >= 0)
        {
                cursor          = *data_inout;
                remaining       = *length_inout;

                if (remaining >= 4)
                {
                        q_out->q_class  = cursor[2] | cursor[3] << 8;
                        q_out->q_type   = cursor[0] | cursor[1] << 8;
                        q_out->next     = NULL;

                        *length_inout   = remaining - 4;
                        *data_inout     = cursor + 4;

                        return 0;
                }

                status = -1;
        }

        return status;
}

I64 DNSRRParse(U8 *packet_data, I64 packet_length, U8 **data_inout, I64 *length_inout, CDNSRR *rr_out)
{ // Parse one resource record at the cursor and advance the cursor past it.
  // The 10 fixed bytes after the name are copied verbatim over rr_out->type .. rr_out->rd_length,
  // so those fields keep the packet's byte order. r_data is left pointing into the packet buffer.
  // Returns 0 on success, a negative value when the name fails to parse or the record is truncated.
        U8 *cursor;
        I64 remaining;
        I64 total;
        I64 status = DNSDomainNameParse(packet_data, packet_length, data_inout, length_inout, &rr_out->name);

        if (status < 0)
                return status;

        cursor          = *data_inout;
        remaining       = *length_inout;

        if (remaining < 10)
                return -1;

        rr_out->next = NULL;
        MemCopy(&rr_out->type, cursor, 10); // type, rr_class, ttl, rd_length in one go

        // fixed part plus the variable-length record data
        total = EndianU16(rr_out->rd_length) + 10;

        if (total > remaining)
                return -1;

        rr_out->r_data = cursor + 10;

        *length_inout   = remaining - total;
        *data_inout     = cursor + total;

        return 0;
}

I64 DNSResponseParse(U16 id, U8 *data, I64 length, CDNSHeader **header_out, CDNSQuestion **questions_out, CDNSRR **answers_out)
{ // Parse the header, question section and answer section of a DNS response.
  //    id:             expected query id, 0 to accept any
  //    data / length:  the received packet
  //    header_out:     on success, set to the header (points into data)
  //    questions_out / answers_out: every parsed entry is pushed onto the front of these chains
  // Returns 0 on success, -1 on failure. On failure, entries already pushed stay on the chains
  // for the caller to free; the entry that failed to parse is released here.
        CDNSHeader              *header;
        CDNSQuestion    *question;
        CDNSRR                  *answer;
        I64                              i;
        I64                              q_total;
        I64                              a_total;
        U8                              *packet_data    = data;
        I64                              packet_length  = length;

        if (length < sizeof(CDNSHeader))
        {
                NetErr("DNS PARSE RESPONSE: Length too short.");
                return -1;
        }

        header = data;
        data   += sizeof(CDNSHeader);
        length -= sizeof(CDNSHeader); // keep bytes-left in step with the cursor, else parsers read past the packet

        if (id != 0 && EndianU16(header->id) != id)
        {
                NetErr("DNS PARSE RESPONSE: Header ID mismatch.");
                return -1;
        }

        q_total = EndianU16(header->q_count);
        a_total = EndianU16(header->a_count);

        for (i = 0; i < q_total; i++)
        {
                question = CAlloc(sizeof(CDNSQuestion));
                if (DNSQuestionParse(packet_data, packet_length, &data, &length, question) < 0)
                {
                        // labels is only non-NULL when the name buffers are still allocated
                        if (question->q_name.labels)
                        {
                                Free(question->q_name.labels[0]);
                                Free(question->q_name.labels);
                        }
                        Free(question);
                        return -1;
                }

                question->next = *questions_out;
                *questions_out = question;
        }

        for (i = 0; i < a_total; i++)
        {
                answer = CAlloc(sizeof(CDNSRR));
                if (DNSRRParse(packet_data, packet_length, &data, &length, answer) < 0)
                {
                        if (answer->name.labels)
                        {
                                Free(answer->name.labels[0]);
                                Free(answer->name.labels);
                        }
                        Free(answer);
                        return -1;
                }

                answer->next = *answers_out;
                *answers_out = answer;
        }

        *header_out = header;
        return 0;

}

U0 DNSQuestionBuild(CDNSQuestion *q, U8 *name)
{ // Fill q with a type A, class IN question for the dotted hostname in name.
  // name is duplicated and the copy is cut into labels in place; labels[0] always holds the
  // base of that copy (even for an empty name) so that Free(labels[0]) releases it.
        U8 *copy = StrNew(name);
        U8 *scan;
        U8 *dot;
        I64 max_labels = 1;

        // one label per dot-separated piece; never fewer than the 16 slots used elsewhere in this file
        for (scan = copy; *scan; scan++)
                if (*scan == '.')
                        max_labels++;
        if (max_labels < 16)
                max_labels = 16;

        q->next                 = NULL;
        q->q_name.labels        = CAlloc(max_labels * sizeof(U8 *));
        q->q_name.labels[0]     = copy;
        q->q_name.num_labels    = 0;
        q->q_type               = DNS_TYPE_A;
        q->q_class              = DNS_CLASS_IN;

        while (*copy)
        {
                q->q_name.labels[q->q_name.num_labels++] = copy;
                dot = StrFirstOcc(copy, ".");

                if (dot)
                { // terminate this label and continue behind the dot
                        *dot = 0;
                        copy = dot + 1;
                }
                else
                        break;
        }
}

// these Free methods bother me a bit...
U0 DNSQuestionFree(CDNSQuestion *q)
{ // Release the name buffers owned by q; q itself is not freed.
  // labels[0] is the base of the single buffer holding all label text, labels is the pointer array.
        if (q->q_name.labels)
        {
                Free(q->q_name.labels[0]);
                Free(q->q_name.labels);

                // leave the name empty so a repeated call is harmless
                q->q_name.labels        = NULL;
                q->q_name.num_labels    = 0;
        }
}

U0 DNSRRFree(CDNSRR *rr)
{ // Release the name buffers owned by rr; rr itself is not freed.
  // r_data is not freed here: DNSRRParse points it into the packet buffer.
        if (rr->name.labels)
        {
                Free(rr->name.labels[0]);
                Free(rr->name.labels);

                // leave the name empty so a repeated call is harmless
                rr->name.labels         = NULL;
                rr->name.num_labels     = 0;
        }
}

U0 DNSQuestionChainFree(CDNSQuestion *questions)
{ // Free every node of a question chain together with the name buffers each node owns.
        CDNSQuestion *node;
        CDNSQuestion *following;

        for (node = questions; node; node = following)
        {
                following = node->next; // read the link before the node is released

                DNSQuestionFree(node);
                Free(node);
        }
}

U0 DNSRRChainFree(CDNSRR *rrs)
{ // Free every node of a resource record chain together with the name buffers each node owns.
  // Shrine sets rrs->next to a CDNSQuestion when it would be a CDNSRR ... assuming it's wrong and fixing it here..
        CDNSRR *node;
        CDNSRR *following;

        for (node = rrs; node; node = following)
        {
                following = node->next; // read the link before the node is released

                DNSRRFree(node);
                Free(node);
        }
}


I64 DNSQueryRun(CUDPSocket *udp_socket, U8 *name, U16 port, CAddressInfo **result_out)
{       // Resolve name to an IPV4 address by sending a type A / class IN question to the global resolver.
        //      udp_socket: a socket already made with UDPSocket(); it is bound here to a random local port
        //      port:       stored as-is in the port field of the returned socket address
        //      result_out: on success receives a newly allocated CAddressInfo; untouched on failure
        // Returns 0 on success, -1 on failure. A successful result is also copied into the DNS cache.
        // IPV4-UDP-based, TODO: take good look at this method to ensure no floating pointers after.
        // note: the UDP socket is not closed in this method, gets closed e.g. in DNSAddressInfoGet
        I64  retries    = 0;
        I64  timeout    = DNS_TIMEOUT;
        U16  local_port = MaxI64(1024, RandU16); // Pick a random port above 1023. (within standard application port range)
        U16  id         = RandU16;
        I64  error      = 0;
        U8   buffer[2048];
        I64  count;
        Bool have;      // TRUE once a usable A record has been turned into a result

        CDNSQuestion     q;
        CDNSHeader      *header;
        CDNSQuestion    *questions;
        CDNSRR          *answers;
        CDNSRR          *a;

        CSocketAddressIPV4       ipv4_addr;
        CSocketAddressIPV4       ipv4_addr_in; // sender address filled in by UDPSocketReceiveFrom, otherwise unused
        CSocketAddressIPV4      *ipv4_addr_temp;
        CAddressInfo            *res;

        //setsockopt(socket, SOL_SOCKET, SO_RCVTIMEO_MS, &timeout, sizeof(timeout))
        udp_socket->receive_timeout_ms = timeout;

        ipv4_addr.family                = AF_INET;
        ipv4_addr.port                  = EndianU16(local_port);
        ipv4_addr.address.address       = INADDR_ANY;

        // UDPSocketBind will be attempted on the udp_socket param, method expects a UDPSocket() result to be made already
        if (UDPSocketBind(udp_socket, &ipv4_addr)) // expected return value is 0
        {
                NetErr("DNS RUN QUERY: Failed to bind UDP socket.");
                return -1;
        }

        DNSQuestionBuild(&q, name);

        while (TRUE) // bounded by DNS_MAX_RETRIES below
        {
                error = DNSQuestionSend(id, local_port, &q);
                if (error < 0)
                {
                        NetErr("DNS RUN QUERY: Failed to Send Question.");
                        error = -1;
                        break; // leave through the common exit so q's buffers are released
                }

                count = UDPSocketReceiveFrom(udp_socket, buffer, sizeof(buffer), &ipv4_addr_in);

                if (count > 0)
                {
                        NetLog("DNS RUN QUERY: Trying Parse Response.");

                        header          = NULL;
                        questions       = NULL;
                        answers         = NULL;
                        have            = FALSE;

                        error = DNSResponseParse(id, buffer, count, &header, &questions, &answers);

                        if (error == 0) // Shrine has (error >= 0), but DNSResponseParse can only return 0 or -1 ..
                        {
                                a = answers;
                                while (a)
                                {
                                        // Shrine has TODO: if multiple acceptable answers, pick one at random, not just first one.
                                        // perhaps we could use r_count in header for that ?

                                        // record fields are still in packet byte order, hence EndianU16
                                        if (EndianU16(a->type)          == DNS_TYPE_A   &&
                                            EndianU16(a->rr_class)      == DNS_CLASS_IN &&
                                            EndianU16(a->rd_length)     == 4)
                                        {
                                                res = CAlloc(sizeof(CAddressInfo));

                                                res->flags              = 0;
                                                res->family             = AF_INET;
                                                res->socket_type        = 0;    // ??
                                                res->protocol           = 0;    // ??
                                                res->address_length     = sizeof(CSocketAddressIPV4);
                                                res->address            = CAlloc(sizeof(CSocketAddressIPV4));
                                                res->canonical_name     = 0;
                                                res->next               = NULL;

                                                ipv4_addr_temp = res->address;

                                                ipv4_addr_temp->family  = AF_INET;
                                                // NOTE(review): the bind address above uses EndianU16 on its port, this does not -- confirm which callers expect
                                                ipv4_addr_temp->port    = port;
                                                // take the address from the record that matched, not from the head of the chain
                                                MemCopy(&ipv4_addr_temp->address.address, a->r_data, 4);

                                                DNSCachePut(name, res);
                                                *result_out = res;
                                                have = TRUE;

                                                break;
                                        }

                                        a = a->next;
                                }
                        }
                        else
                        {
                                NetErr("DNS RUN QUERY: Failed a DNS Parse Response.");
                        }

                        // a failed parse can leave partially built chains behind, so free in both cases
                        DNSQuestionChainFree(questions);
                        DNSRRChainFree(answers);

                        if (have)
                                break;

                        if (error == 0)
                        {
                                // Shrine comment: 'at this point, we could try iterative resolution,
                                // but all end-user DNS servers would have tried that already'

                                NetErr("DNS RUN QUERY: Failed to find suitable answer in reply.");
                                error = -1;
                        }
                }

                if (++retries == DNS_MAX_RETRIES)
                {
                        NetErr("DNS RUN QUERY: Failed, max retries reached.");
                        error = -1;
                        break;
                }
        }

        DNSQuestionFree(&q);
        return error;
}

// Shrine has port arg as U8 *service with a no_warn and says it should be parsed as port, allowing that here
// Also has CAddressInfo *hints with a no_warn, omitting that for now
I64 DNSAddressInfoGet(U8 *node_name, U16 port, CAddressInfo **result)
{ // Resolve node_name, answering from the DNS cache when possible and otherwise querying over UDP.
  //    port:   written into the returned IPV4 socket address, for cached and fresh results alike
  //    result: on success receives a newly allocated CAddressInfo that the caller frees
  // Returns 0 on success, -1 on failure.
        I64                      error;
        CUDPSocket              *udp_socket;
        CSocketAddressIPV4      *ipv4_addr;
        CDNSHash                *cached_entry = DNSCacheFind(node_name);

        if (cached_entry)
        {
                *result = CAlloc(sizeof(CAddressInfo));
                AddressInfoCopy(*result, &cached_entry->info);
                //(*res)->flags |= AI_CACHED; // TODO: add AI_CACHED define (maybe a better name?) not used anywhere i don't think..

                // the cached copy carries the port of the query that filled the cache; use this caller's port
                if ((*result)->family == AF_INET && (*result)->address)
                {
                        ipv4_addr = (*result)->address;
                        ipv4_addr->port = port;
                }

                return 0;
        }

        udp_socket = UDPSocket(AF_INET);
        error = 0;

        if (udp_socket)
        {
                error = DNSQueryRun(udp_socket, node_name, port, result);

                UDPSocketClose(udp_socket);
        }
        else
        {
                NetErr("DNS GET ADDRESS INFO: Failed to make UDP Socket.");
                error = -1;
        }

        return error;
}

U0 DNSResolverIPV4Set(U32 ip)
{ // Make ip the DNS server that queries are sent to and mark the global resolver as IPV4.
        CIPV4Address *resolver = &dns_globals.dns_ip;

        resolver->address       = ip;
        dns_globals.addr_family = AF_INET;
}

U0 Host(U8 *hostname)
{ // Resolve hostname through DNSAddressInfoGet (port 0) and print every returned address info entry.
  // getaddrinfo() for whole system in Shrine ends up as pointer to DNSAddressInfoGet.. should we do something similar?
        CAddressInfo            *results = NULL;
        CAddressInfo            *entry;
        CSocketAddressIPV4      *sock_addr;
        I64                      shown = 0;

        if (DNSAddressInfoGet(hostname, NULL, &results) < 0)
        {
                NetErr("HOST(): Failed at DNS Get Address Info.");
                AddressInfoFree(results);
                return;
        }

        "Results:\n\n";

        for (entry = results; entry; entry = entry->next)
        {
                "Result %d:\n", ++shown;

                "       flags:                  0x%04X \n", entry->flags;
                "       family:                 %d    \n", entry->family;
                "       socket type:    %d    \n", entry->socket_type;
                "       protocol:               %d    \n", entry->protocol;
                "       address length: %d    \n", entry->address_length;

                if (entry->family == AF_INET)
                {
                        sock_addr = entry->address;
                        "       address:                %s    \n", NetworkToPresentation(AF_INET, &sock_addr->address);
                }
                else if (entry->family == AF_INET6)
                {
                        "       address:                IPV6    \n"; // FIXME
                }
                else if (entry->family == AF_UNSPEC)
                {
                        "       address:                AF_UNSPEC    \n";
                }
                else
                {
                        "       address:                INVALID    \n";
                }
        }

        "\n";

        AddressInfoFree(results);
}

U0 DNSRep()
{ // Print every entry of the DNS cache: hash node address, hostname and resolved IP address.
        I64                      bucket;
        CDNSHash                *node;
        CSocketAddressIPV4      *sock_addr;

        "$LTBLUE$DNS Report:$FG$\n\n";

        // walk each bucket of the table and every node chained inside it
        for (bucket = 0; bucket <= dns_cache->mask; bucket++)
        {
                for (node = dns_cache->body[bucket]; node; node = node->next)
                {
                        "DNS Hash @ 0x%X:\n", node;
                        "       Hostname:               %s\n", node->str;

                        if (node->info.family == AF_INET)
                        {
                                sock_addr = node->info.address;

                                "       IP Address:             %s\n",
                                        NetworkToPresentation(node->info.family, &sock_addr->address);
                        }
                        else if (node->info.family == AF_INET6)
                        {
                                "       IP Address:             IPV6\n"; // FIXME
                        }
                        else if (node->info.family == AF_UNSPEC)
                        {
                                "       IP Address:             AF_UNSPEC";
                        }
                        else
                        {
                                "       IP Address:             INVALID";
                        }

                        "\n";
                }
        }
}

// Runs when this file is loaded: creates dns_cache and clears the global resolver address.
DNSCacheInit;