using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;

namespace DnsForwarder
{
    internal class Program
    {
        // Upstream DNS server returned by MatchServer when no other rule matches (set from args[1] in Main).
        private static string _defaultDns;
        // Upstream DNS server returned by MatchServer for names ending in ".cn" (set from args[2] in Main).
        private static string _cnDns;
        // Maps a domain read from the list file to an index into _chinaListDns; built by LoadListFile.
        private static IReadOnlyDictionary<string, int> _chinaList;
        // Distinct upstream server strings found in the list file, addressed by the values of _chinaList.
        private static string[] _chinaListDns;

        // UDP socket bound in Main; receives client queries and sends the upstream replies back.
        private static Socket _listener;
        // Queue of log lines added by Callback; Main takes from it and writes each line to the console.
        private static BlockingCollection<string> _consoleOutout;

        /// <summary>
        /// Program entry point. Validates the four command-line arguments, loads the
        /// domain list, binds the UDP listener on port 53, starts the first receive and
        /// then prints queued log lines forever.
        /// </summary>
        /// <param name="args">
        /// Listen address, default DNS server, .CN DNS server and the path to the
        /// dnsmasq-china-list file, in that order.
        /// </param>
        private static void Main(string[] args)
        {
            const int expectedArgumentCount = 4;
            if (args.Length != expectedArgumentCount)
            {
                Console.WriteLine("<listen address> <default dns server ip> <.CN dns server ip> <path to dnsmasq-china-list file>");
                Environment.Exit(-1);
                return;
            }

            var listenText = args[0];
            IPAddress listenAddress;
            var listenAddressIsValid = IPAddress.TryParse(listenText, out listenAddress);
            if (!listenAddressIsValid)
            {
                Console.WriteLine("Invalid listen address.");
                Environment.Exit(-2);
                return;
            }

            var listFilePath = args[3];

            Console.WriteLine("Starting...");
            Console.WriteLine($"Listen address: {listenText}");

            // Remember the two upstream servers before echoing them back to the user.
            _defaultDns = args[1];
            Console.WriteLine($"Default Server: {_defaultDns}");
            _cnDns = args[2];
            Console.WriteLine($".CN DNS Server: {_cnDns}");
            Console.WriteLine($"dnsmasq-china-list: {listFilePath}");

            Console.WriteLine("Loading list file...");
            LoadListFile(listFilePath);
            Console.WriteLine("OK. Listening...");

            var localEndPoint = new IPEndPoint(listenAddress, 53);
            _listener = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
            _listener.Bind(localEndPoint);

            _consoleOutout = new BlockingCollection<string>();

            StartHandleRequest();

            // The main thread only drains the log queue; Take() blocks until a line is available.
            for (;;)
            {
                var logLine = _consoleOutout.Take();
                Console.WriteLine(logLine);
            }
        }

        /// <summary>
        /// Posts one asynchronous receive on the listener socket. A fresh 1500-byte
        /// buffer is allocated for the datagram and handed to <c>Callback</c> as the
        /// async state object.
        /// </summary>
        private static void StartHandleRequest()
        {
            var packet = new byte[1500];

            // Placeholder endpoint; the actual sender is obtained in Callback via EndReceiveFrom.
            EndPoint anySender = new IPEndPoint(IPAddress.Any, 0);

            _listener.BeginReceiveFrom(
                packet,
                0,
                packet.Length,
                SocketFlags.None,
                ref anySender,
                Callback,
                packet);
        }

        /// <summary>
        /// Completion handler for the receive posted by <c>StartHandleRequest</c>.
        /// Finishes the receive, immediately posts the next one, then extracts the
        /// queried name, picks the upstream server, forwards the query and relays the
        /// reply (if any) to the original sender. One log line is queued per request.
        /// </summary>
        /// <param name="ar">
        /// Async result of the receive; its <c>AsyncState</c> is the byte buffer that
        /// was passed to <c>BeginReceiveFrom</c>.
        /// </param>
        private static void Callback(IAsyncResult ar)
        {
            var flagNextStarted = false;
            EndPoint from = new IPEndPoint(IPAddress.Any, 0);
            string domain = null;
            string target = null;

            try
            {
                var count = _listener.EndReceiveFrom(ar, ref from);

                // Re-arm the listener before doing the upstream lookup so that other
                // clients are not kept waiting while this request is being forwarded.
                StartHandleRequest();
                flagNextStarted = true;

                var buf = (byte[])ar.AsyncState;

                domain = ExtractDomainName(buf);

                target = MatchServer(domain);

                var dnsResponse = GetDnsResponse(buf, count, target);
                if (dnsResponse != null) _listener.SendTo(dnsResponse, from);
                _consoleOutout.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target}]\t{domain}");
            }
            catch (Exception e)
            {
                // A trailing bare `catch` is intentionally absent: C# assemblies wrap
                // non-Exception throwables in RuntimeWrappedException by default, so a
                // general catch after this clause can never run (compiler warning CS1058).
                _consoleOutout.Add($"{DateTime.Now:yyyyMMdd HH:mm:ss} {from} [{target ?? "Unknown"}]\t{domain ?? "Unknown"} Err:{e.Message}");
            }
            finally
            {
                // If the receive itself failed, the next receive has not been posted yet.
                if (!flagNextStarted) StartHandleRequest();
            }
        }

        /// <summary>
        /// Reads the name of the first question in a DNS query and returns it in
        /// dotted form (for example <c>www.example.com</c>).
        /// </summary>
        /// <param name="buf">
        /// Buffer holding the DNS message. The name is expected to start right after
        /// the 12-byte header as a sequence of length-prefixed labels terminated by a
        /// zero byte.
        /// </param>
        /// <returns>
        /// The labels joined with dots, decoded as ASCII. An empty string is returned
        /// when the question is for the root name (zero labels).
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="buf"/> is null.</exception>
        /// <exception cref="FormatException">
        /// The name runs past the end of <paramref name="buf"/> before its terminating
        /// zero byte is found.
        /// </exception>
        protected static string ExtractDomainName(byte[] buf)
        {
            if (buf == null) throw new ArgumentNullException(nameof(buf));

            const int headerLength = 12;

            var name = new StringBuilder();
            var ptr = headerLength;

            while (true)
            {
                if (ptr >= buf.Length)
                {
                    throw new FormatException("Malformed DNS query: question name is not terminated.");
                }

                int labelLength = buf[ptr];
                if (labelLength == 0) break; // zero-length label marks the end of the name

                var labelStart = ptr + 1;
                if (labelStart + labelLength > buf.Length)
                {
                    throw new FormatException("Malformed DNS query: label extends past the end of the packet.");
                }

                if (name.Length > 0) name.Append('.');
                name.Append(Encoding.ASCII.GetString(buf, labelStart, labelLength));

                ptr = labelStart + labelLength;
            }

            return name.ToString();
        }

        /// <summary>
        /// Forwards a DNS query to an upstream server over UDP and waits for the reply.
        /// </summary>
        /// <param name="buf">Buffer containing the query datagram.</param>
        /// <param name="count">Number of bytes from the start of <paramref name="buf"/> to send.</param>
        /// <param name="host">Host name or IP address of the upstream server.</param>
        /// <param name="port">UDP port of the upstream server; defaults to 53.</param>
        /// <returns>
        /// The datagram received from the upstream server, or <c>null</c> when nothing
        /// arrives within two seconds.
        /// </returns>
        protected static byte[] GetDnsResponse(byte[] buf, int count, string host, int port = 53)
        {
            const int upstreamTimeoutMs = 2000;

            using var to = new UdpClient();
            // Honour the caller-supplied port instead of always talking to port 53.
            to.Connect(host, port);
            to.Send(buf, count);

            // Handle upstream timeout: wait a bounded time for the receive to complete.
            var asyncResult = to.BeginReceive(null, null);
            if (!asyncResult.AsyncWaitHandle.WaitOne(upstreamTimeoutMs))
            {
                return null;
            }

            IPEndPoint remoteEP = null;
            return to.EndReceive(asyncResult, ref remoteEP);
        }

        /// <summary>
        /// Chooses the upstream DNS server for a queried name.
        /// </summary>
        /// <remarks>
        /// Names ending in <c>.cn</c> go to the .CN server. Otherwise the name and
        /// then each of its parent suffixes (dropping the left-most label each time)
        /// is looked up in the loaded domain list; the first hit decides the server.
        /// If nothing matches, the default server is used. Matching is
        /// case-insensitive and independent of the current culture.
        /// </remarks>
        /// <param name="domain">Queried domain name in dotted form.</param>
        /// <returns>The upstream server string to forward the query to.</returns>
        protected static string MatchServer(string domain)
        {
            // Invariant lower-casing: culture-specific rules (e.g. Turkish dotless i)
            // must not change how host names are matched.
            var lower = domain.ToLowerInvariant();
            if (lower.EndsWith(".cn", StringComparison.Ordinal)) return _cnDns;

            var candidate = lower;
            while (true)
            {
                if (_chinaList.TryGetValue(candidate, out var tar)) return _chinaListDns[tar];

                // Strip the left-most label and retry with the parent domain.
                var dot = candidate.IndexOf('.');
                if (dot < 0) break;
                candidate = candidate.Substring(dot + 1);
            }

            return _defaultDns;
        }

        /// <summary>
        /// Loads one or more dnsmasq-style list files and builds the domain lookup
        /// table used by <c>MatchServer</c>.
        /// </summary>
        /// <remarks>
        /// Only lines of the exact form <c>server=/domain/dns</c> are used. Lines
        /// starting with <c>#</c>, lines with a different directive or a different
        /// number of <c>/</c>-separated parts, and lines with an empty server part are
        /// ignored. Domains are stored lower-cased. When a domain appears more than
        /// once, the last occurrence wins.
        /// </remarks>
        /// <param name="paths">Paths of the list files to read.</param>
        protected static void LoadListFile(params string[] paths)
        {
            var lines = paths.SelectMany(File.ReadAllLines);

            var dic = new Dictionary<string, int>();
            var tar = new List<string>();

            foreach (var line in lines)
            {
                var p = line.Trim();
                if (p.StartsWith("#", StringComparison.Ordinal)) continue;

                var parts = p.Split('/');
                // Both conditions must hold for a usable rule; rejecting on either one
                // keeps the parts[1]/parts[2] accesses below in range.
                if (parts.Length != 3 || parts[0] != "server=") continue;

                // Lookups are done on lower-cased names, so store keys the same way.
                var domain = parts[1].ToLowerInvariant();
                var dns = parts[2];
                if (dns.Length == 0) continue; // no upstream address to forward to

                var dnsIndex = tar.IndexOf(dns);
                if (dnsIndex == -1)
                {
                    tar.Add(dns);
                    dnsIndex = tar.Count - 1;
                }

                dic[domain] = dnsIndex;
            }

            _chinaListDns = tar.ToArray();
            _chinaList = new ReadOnlyDictionary<string, int>(dic);
        }
    }
}