package examples.net.udp;

import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Parses a raw DNS message (RFC 1035) into its header fields, its question section and its
 * answer, authority (name server) and additional resource-record sections.
 *
 * <p>Created by geo on 28.04.2010 8:38:24.
 */
public class DNSResponse {
    /** Maximum wire length of a domain name, length octets included (RFC 1035, section 2.3.4). */
    private static final int MAX_NAME_LENGTH = 255;
    /** Defensive cap on compression pointers followed while decoding one name; stops pointer loops. */
    private static final int MAX_COMPRESSION_HOPS = 128;

    private static final int TYPE_A = 1;
    private static final int TYPE_NS = 2;
    private static final int TYPE_CNAME = 5;

    /** Message ID (first 16 bits of the message), read as an unsigned value. */
    private final int id;
    /** Raw 16-bit flags word that follows the ID. */
    private final int header;
    private final List<Question> questions;
    private final List<Record> answers;
    private final List<Record> nameServers;
    private final List<Record> resources;

    /**
     * {@link ByteArrayInputStream} over a single DNS message. Besides sequential reads it reports
     * the offset of the next byte within the message and gives random access to the message bytes,
     * which is what name-compression pointers refer to. The offset is derived from the inherited
     * {@code pos} field, so it stays exact for every read method, for {@code skip} and at EOF.
     */
    private static final class MessageInputStream extends ByteArrayInputStream {
        /** Index in {@code buf} of the first message byte. */
        private final int start;

        private MessageInputStream(final byte[] buffer, final int offset, final int length) {
            super(buffer, offset, length);
            this.start = offset;
        }

        /** Returns the offset, relative to the start of the message, of the next byte to be read. */
        int position() {
            return pos - start;
        }

        /** Returns the number of message bytes actually present in the underlying array. */
        int messageLength() {
            return Math.max(0, count - start);
        }

        /** Returns the unsigned byte at {@code messageOffset}; fails if it lies outside the message. */
        int byteAt(final int messageOffset) throws EOFException {
            checkRange(messageOffset, 1);
            return buf[start + messageOffset] & 0xff;
        }

        /** Decodes {@code length} bytes starting at {@code messageOffset} as US-ASCII. */
        String ascii(final int messageOffset, final int length) throws IOException {
            checkRange(messageOffset, length);
            return new String(buf, start + messageOffset, length, "US-ASCII");
        }

        private void checkRange(final int messageOffset, final int length) throws EOFException {
            // Written as a subtraction so that large offsets cannot overflow the comparison
            if (messageOffset < 0 || length < 0 || messageOffset > messageLength() - length) {
                throw new EOFException("Bytes [" + messageOffset + ", " + (messageOffset + length)
                        + ") lie outside the " + messageLength() + "-byte message");
            }
        }
    }

    /** {@link DataInputStream} over a {@link MessageInputStream} that exposes the read offset. */
    private static final class CountingDataInputStream extends DataInputStream {
        private CountingDataInputStream(final MessageInputStream in) {
            super(in);
        }

        /**
         * Returns the wrapped message stream. DataInputStream keeps no read-ahead buffer for the
         * methods used here, so the wrapped stream's position equals the bytes consumed so far.
         */
        private MessageInputStream message() {
            return (MessageInputStream) in;
        }

        /** Returns the offset within the message of the next byte to be read. */
        public int getPosition() {
            return message().position();
        }
    }

    /**
     * Parses the DNS message stored in {@code buffer[offset, offset + length)}.
     *
     * <p>Compression pointers are interpreted relative to {@code offset}, the first message byte.
     * Truncated or malformed input is reported as an {@link AssertionError} whose cause is the
     * underlying {@link IOException} (typically an {@link EOFException}).
     *
     * @param buffer array holding the message, e.g. the payload of a received datagram
     * @param offset index of the first message byte in {@code buffer}
     * @param length number of message bytes
     */
    public DNSResponse(byte[] buffer, int offset, int length) {
        final CountingDataInputStream dis =
                new CountingDataInputStream(new MessageInputStream(buffer, offset, length));
        try {
            try {
                id = dis.readUnsignedShort();
                header = dis.readUnsignedShort();
                final int questionCount = dis.readUnsignedShort();
                final int answerCount = dis.readUnsignedShort();
                final int nameServerCount = dis.readUnsignedShort();
                final int resourceCount = dis.readUnsignedShort();

                questions = Collections.unmodifiableList(readQuestions(questionCount, dis));
                answers = Collections.unmodifiableList(readRecords(answerCount, dis));
                nameServers = Collections.unmodifiableList(readRecords(nameServerCount, dis));
                resources = Collections.unmodifiableList(readRecords(resourceCount, dis));
            } finally {
                dis.close();
            }
        } catch (IOException e) {
            throw new AssertionError(e);
        }
    }

    /** Returns the 16-bit message ID as an unsigned value. */
    public int getId() {
        return id;
    }

    /** Returns the raw 16-bit flags word (QR, Opcode, AA, TC, RD, RA, Z, RCODE). */
    public int getFlags() {
        return header;
    }

    /** Returns whether the QR bit marks this message as a response. */
    public boolean isResponse() {
        return getBits(header, 15, 1) == 1;
    }

    /** Returns whether the TC bit marks this message as truncated. */
    public boolean isTruncated() {
        return getBits(header, 9, 1) == 1;
    }

    /** Returns the 4-bit RCODE field (0 means no error condition). */
    public int getResponseCode() {
        return getBits(header, 0, 4);
    }

    /** Returns the question section (unmodifiable). */
    public List<Question> getQuestions() {
        return questions;
    }

    /** Returns the answer section (unmodifiable). */
    public List<Record> getAnswers() {
        return answers;
    }

    /** Returns the authority (name server) section (unmodifiable). */
    public List<Record> getNameServers() {
        return nameServers;
    }

    /** Returns the additional-records section (unmodifiable). */
    public List<Record> getResources() {
        return resources;
    }

    /** Reads {@code n} consecutive resource records. */
    private static List<Record> readRecords(final int n, final CountingDataInputStream dis)
            throws IOException {
        final List<Record> records = new ArrayList<Record>();
        for (int i = 0; i < n; i++) {
            records.add(readRecord(dis));
        }
        return records;
    }

    /** Reads {@code questionCount} consecutive question entries. */
    private static List<Question> readQuestions(
            final int questionCount, final CountingDataInputStream dis
    ) throws IOException {
        final List<Question> questions = new ArrayList<Question>();
        for (int i = 0; i < questionCount; i++) {
            questions.add(readQuestion(dis));
        }
        return questions;
    }

    /**
     * Reads one resource record and leaves the stream exactly at the end of its RDATA, as declared
     * by RDLENGTH, whatever the record type's parser consumed.
     */
    private static Record readRecord(final CountingDataInputStream dis) throws IOException {
        final Header header = readHeader(dis);
        dis.readInt(); // TTL: consumed to stay aligned; Record has no field for it
        final int dataLength = dis.readUnsignedShort();
        final int dataStart = dis.getPosition();

        final Record record;
        switch (header.type) {
            case TYPE_A:
                // A-record RDATA is a 4-byte IPv4 address; keep any other size as raw data
                if (dataLength == 4) {
                    record = new ARecord(header, readFully(dis, dataLength));
                } else {
                    record = new UnknownRecord(header, readFully(dis, dataLength));
                }
                break;
            case TYPE_NS:
                record = new NSRecord(header, readName(dis));
                break;
            case TYPE_CNAME:
                record = new CNameRecord(header, readName(dis));
                break;
            default:
                record = new UnknownRecord(header, readFully(dis, dataLength));
                break;
        }

        final int consumed = dis.getPosition() - dataStart;
        if (consumed > dataLength) {
            throw new IOException("RDATA of type " + header.type + " at offset " + dataStart
                    + " overruns its declared length " + dataLength);
        }
        // Skip any trailing RDATA so the next record starts where the message says it does
        skipFully(dis, dataLength - consumed);
        return record;
    }

    /** Reads the owner name, TYPE and CLASS shared by questions and resource records. */
    private static Header readHeader(final CountingDataInputStream dis) throws IOException {
        final String name = readName(dis);
        final int type = dis.readUnsignedShort();
        final int clazz = dis.readUnsignedShort();
        return new Header(name, type, clazz);
    }

    private static Question readQuestion(final CountingDataInputStream dis) throws IOException {
        return new Question(readHeader(dis));
    }

    /**
     * Decodes the domain name at the current position and advances the stream past its in-place
     * encoding, i.e. up to and including the first compression pointer or the root label.
     *
     * <p>Labels are joined with a trailing dot ({@code "www.example.com."}); the root name is
     * {@code ""}. Pointers are resolved against the raw message bytes, so they may point anywhere
     * in the message, including into RDATA this class does not parse.
     */
    private static String readName(final CountingDataInputStream dis) throws IOException {
        final MessageInputStream message = dis.message();
        final int start = dis.getPosition();
        final StringBuilder name = new StringBuilder();
        int cursor = start;
        int inlineEnd = -1; // offset just past the in-place encoding, known at the first pointer or root
        int wireLength = 1; // accounts for the terminating root label
        int hops = 0;
        while (true) {
            final int length = message.byteAt(cursor);
            final int labelType = length & 0xC0;
            if (labelType == 0xC0) {
                // Compression pointer: the low 14 bits are the offset of the rest of the name
                if (++hops > MAX_COMPRESSION_HOPS) {
                    throw new IOException("Too many compression pointers in name at offset " + start);
                }
                final int target = ((length & 0x3F) << 8) | message.byteAt(cursor + 1);
                if (inlineEnd < 0) {
                    inlineEnd = cursor + 2;
                }
                cursor = target;
            } else if (labelType != 0) {
                // The 01 and 10 prefixes are reserved (RFC 1035, section 4.1.4)
                throw new IOException("Unsupported label type 0x" + Integer.toHexString(length)
                        + " at offset " + cursor);
            } else if (length == 0) {
                if (inlineEnd < 0) {
                    inlineEnd = cursor + 1;
                }
                break;
            } else {
                wireLength += 1 + length;
                if (wireLength > MAX_NAME_LENGTH) {
                    throw new IOException("Name at offset " + start + " exceeds "
                            + MAX_NAME_LENGTH + " octets");
                }
                name.append(message.ascii(cursor + 1, length)).append('.');
                cursor += 1 + length;
            }
        }
        skipFully(dis, inlineEnd - start);
        return name.toString();
    }

    /** Reads exactly {@code length} bytes. */
    private static byte[] readFully(final DataInputStream dis, final int length) throws IOException {
        final byte[] buffer = new byte[length];
        dis.readFully(buffer);
        return buffer;
    }

    /** Advances {@code dis} by exactly {@code count} bytes. */
    private static void skipFully(final DataInputStream dis, final int count) throws IOException {
        if (dis.skipBytes(count) != count) {
            throw new EOFException("Could not skip " + count + " bytes");
        }
    }

    /** Extracts the {@code length}-bit field that starts {@code offset} bits from the low end. */
    private static int getBits(final int field, final int offset, final int length) {
        return (field >>> offset) & ((1 << length) - 1);
    }

    /** Owner name, TYPE and CLASS common to questions and resource records. */
    private static class Header {
        private final String name;
        private final int type;
        private final int clazz;

        public Header(final String name, final int type, final int clazz) {
            this.name = name;
            this.type = type;
            this.clazz = clazz;
        }
    }

    /** Base class for parsed records: owner name, TYPE and CLASS. */
    public static class Record {
        private final String name;
        private final int type;
        private final int clazz;

        public Record(final Header header) {
            this.name = header.name;
            this.type = header.type;
            this.clazz = header.clazz;
        }

        /** Returns the owner name with a trailing dot, or {@code ""} for the root. */
        public String getName() {
            return name;
        }

        /** Returns the numeric TYPE code. */
        public int getType() {
            return type;
        }

        /** Returns the numeric CLASS code. */
        public int getRecordClass() {
            return clazz;
        }
    }

    /** An entry of the question section (name, TYPE and CLASS only). */
    public static class Question extends Record {
        public Question(final Header header) {
            super(header);
        }
    }

    /** Type 5 (CNAME) record. */
    public static class CNameRecord extends Record {
        private final String cname;

        public CNameRecord(final Header header, final String cname) {
            super(header);
            this.cname = cname;
        }

        /** Returns the canonical name. */
        public String getCname() {
            return cname;
        }
    }

    /** Type 2 (NS) record. */
    public static class NSRecord extends Record {
        private final String nsdname;

        public NSRecord(final Header header, final String nsdname) {
            super(header);
            this.nsdname = nsdname;
        }

        /** Returns the name server's domain name. */
        public String getNsdname() {
            return nsdname;
        }
    }

    /** Type 1 (A) record holding an IPv4 address. */
    public static class ARecord extends Record {
        private final String address;

        /**
         * @param header owner name, TYPE and CLASS
         * @param address RDATA bytes; the first four are formatted as a dotted-quad address
         */
        public ARecord(final Header header, final byte[] address) {
            super(header);
            this.address = (address[0] & 0xff) + "." + (address[1] & 0xff) + "."
                    + (address[2] & 0xff) + "." + (address[3] & 0xff);
        }

        /** Returns the address in dotted-quad form, e.g. {@code "192.0.2.1"}. */
        public String getAddress() {
            return address;
        }
    }

    /** Any record type without a dedicated parser; keeps the raw RDATA. */
    public static class UnknownRecord extends Record {
        private final byte[] data;

        public UnknownRecord(final Header header, final byte[] data) {
            super(header);
            this.data = data;
        }

        /** Returns a copy of the raw RDATA bytes. */
        public byte[] getData() {
            return data.clone();
        }
    }
}
