import struct
import numpy as np
import time
from .common import IFFrame
from .common import SpectralFrame
from .timestamp import Timestamp


def read_u32(data: bytes, start_byte: int = 0) -> int:
    """Decode the big-endian unsigned 32-bit integer that begins at byte offset ``start_byte`` of ``data``."""
    stop_byte = start_byte + 4
    (value,) = struct.unpack(">I", data[start_byte:stop_byte])
    return value

def read_i16(data: bytes, start_byte: int = 0) -> int:
    """Decode the big-endian two's-complement 16-bit integer that begins at byte offset ``start_byte`` of ``data``."""
    stop_byte = start_byte + 2
    (value,) = struct.unpack(">h", data[start_byte:stop_byte])
    return value

def read_i32(data: bytes, start_byte: int = 0) -> int:
    """Decode the big-endian two's-complement 32-bit integer that begins at byte offset ``start_byte`` of ``data``."""
    stop_byte = start_byte + 4
    (value,) = struct.unpack(">i", data[start_byte:stop_byte])
    return value

def read_i64(data: bytes, start_byte: int = 0) -> int:
    """Decode the big-endian two's-complement 64-bit integer that begins at byte offset ``start_byte`` of ``data``."""
    stop_byte = start_byte + 8
    (value,) = struct.unpack(">q", data[start_byte:stop_byte])
    return value

def read_f64(data: bytes, start_byte: int = 0) -> float:
    """Decode the big-endian IEEE-754 double that begins at byte offset ``start_byte`` of ``data``."""
    stop_byte = start_byte + 8
    (value,) = struct.unpack('>d', data[start_byte:stop_byte])
    return value

def read_word(data: bytes, word_index: int) -> int:
    """
    Fetch one 32-bit VITA-49 word from a byte buffer as an unsigned integer

    :param data: Byte buffer holding the words
    :param word_index: Position of the word counted in 32-bit words (the byte offset is word_index * 4)
    :return: Unsigned value of the word
    """
    offset = word_index * 4
    # Every VITA-49 word is big-endian, hence the ">I" (big-endian unsigned 32-bit) struct format.
    (word,) = struct.unpack(">I", data[offset:offset + 4])
    return word

def read_uint_from_word(word: int, start_index: int, size: int) -> int:
    """
    Extract an unsigned bit field from a VITA-49 word

    :param word: VITA-49 word to read from
    :param start_index: Position of the field's least-significant bit (bit 0 is the LSB of the word)
    :param size: Width of the field in bits
    :return: Value of the field as an unsigned integer
    """
    # A run of `size` one-bits, e.g. size=4 -> 0b1111
    field_mask = (1 << size) - 1
    shifted = word >> start_index
    return shifted & field_mask

def read_bit_from_word(word: int, bit_index: int) -> bool:
    """
    Test a single bit of a VITA-49 word

    :param word: VITA-49 word to read from
    :param bit_index: Position of the bit to test (bit 0 is the LSB of the word)
    :return: True if the bit is set, False otherwise
    """
    return bool(word & (1 << bit_index))

def read_frequency(data: bytes, start_byte: int = 0) -> float:
    """
    Parses a frequency value from a byte buffer

    VITA-49 encodes a number of frequency fields as a 64-bit signed (two's-complement) fixed-point number with 20 bits
    of fractional resolution and 44 bits of whole-number resolution (the upper 44 bits, including the sign bit).

    :param data: Byte buffer to read from
    :param start_byte: Byte index to start reading the frequency from
    :return: Frequency (the raw signed 64-bit value divided by 2**20)
    """
    # 1048576 == 2**20: moves the radix point back past the 20 fractional bits.
    return read_i64(data, start_byte) / 1048576.0

def parse_sample_format(word1: int, word2: int):
    """
    Parses the data format field from a VITA-49 context packet

    The data format context field contains a number of details about the sample format of the digital IF data contained
    in a VRT packet stream.  Sceptre only makes use of a subset of the possibilities this field can describe, namely
    real or complex, integer or floating-point unpacked samples.  Sceptre does not support more advanced layouts
    like link-efficient bit packing or vectors of samples, so those are not handled by this function.

    :param word1: First word of the data format context field
    :param word2: Second word of the data format context field (not used by this function)
    :return: The numpy.dtype representation of the sample format.  Complex integer formats are returned as a structured
             dtype with 're' and 'im' fields because NumPy has no native complex integer types.  Unrecognized formats
             fall back to np.dtype(np.uint8).
    """
    # See section 7.1.5.18
    # All Sceptre uses to define the sample format are these three fields of the first word:
    real_cpx = read_uint_from_word(word1, 29, 2)  # Real/Complex type: 0 = real, any other value is treated as complex
    dif = read_uint_from_word(word1, 24, 5)  # Data Item Format
    # The Data Item Size field is 6 bits wide (bits 5..0) and stores (size - 1).  Reading only 5 bits would turn a
    # 64-bit size (stored as 63) into 32, so the 64-bit integer formats could never be selected.
    data_bits = read_uint_from_word(word1, 0, 6) + 1
    is_complex = real_cpx != 0

    if dif in (0, 16):
        # Fixed-point integer samples: 0 = signed, 16 = unsigned
        if dif == 0:
            component_types = {8: np.int8, 16: np.int16, 32: np.int32, 64: np.int64}
        else:
            component_types = {8: np.uint8, 16: np.uint16, 32: np.uint32, 64: np.uint64}
        component = component_types.get(data_bits)
        if component is not None:
            if is_complex:
                return np.dtype([('re', component), ('im', component)])
            return np.dtype(component)
    elif dif == 14:
        # 32-bit float
        return np.dtype(np.complex64 if is_complex else np.float32)
    elif dif == 15:
        # 64-bit double
        return np.dtype(np.complex128 if is_complex else np.float64)

    # Unsupported format or item size: fall back to raw bytes
    return np.dtype(np.uint8)

def convert_to_float(x: np.ndarray):
    """
    Convert samples to floating point, keeping real data real and complex data complex

    Raw VRT payloads may use complex integer layouts (a structured ('re', 'im') record), which NumPy cannot do
    arithmetic on directly.  This turns any supported input into float32 (real) or complex64 (complex) so the
    result is easy to process.

    :param x: NumPy array of input samples
    :return: float32 array for real input, complex64 array for complex input
    """
    # Complex integers: reinterpret the (re, im) records as a flat run of integers, widen those to float32, then
    # reinterpret each adjacent float32 pair as one complex64 value.
    component_types = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64)
    for component in component_types:
        if x.dtype == np.dtype([('re', component), ('im', component)]):
            return x.view(component).astype(np.float32).view(np.complex64)

    if x.dtype == np.complex64:
        # Already in the target format; hand back the same array
        return x
    if x.dtype == np.complex128:
        return x.astype(np.complex64)
    # Everything else is treated as real-valued
    return x.astype(np.float32)

class ECEFEphemeris:
    """
    Represents the ephemeris of the receiver in ECEF coordinates
    """
    timestamp: Timestamp
    """Timestamp associated with this position"""
    manufacturer_oui: int
    """Manufacturer OUI code"""
    position: tuple[float, float, float] | None
    """Position (x, y, z)"""
    attitude: tuple[float, float, float] | None
    """Attitude (alpha, beta, phi)"""
    velocity: tuple[float, float, float] | None
    """Velocity (dx/dt, dy/dt, dz/dt)"""

    def __init__(self):
        self.timestamp = Timestamp()
        self.manufacturer_oui = 0
        self.position = None
        self.attitude = None
        self.velocity = None

    @staticmethod
    def _read_fixed_point_triplet(payload: bytes, start_word: int, scale: float) -> tuple[float, float, float] | None:
        """
        Reads three consecutive signed 32-bit fixed-point words and divides each by ``scale``

        :param payload: Array of bytes holding the ephemeris field
        :param start_word: Index (in 32-bit words) of the first of the three words
        :param scale: Divisor that converts the raw fixed-point integers to floating point
        :return: The scaled (a, b, c) tuple, or None if any word holds the 0x7FFFFFFF "unknown" value
        """
        nullval = 0x7fffffff
        raw = tuple(read_i32(payload, (start_word + i) * 4) for i in range(3))
        if nullval in raw:
            return None
        return (raw[0] / scale, raw[1] / scale, raw[2] / scale)

    def parse(self, payload: bytes) -> int:
        """
        Parses the context ECEF ephemeris field from an array of bytes into this object

        Position, attitude, and velocity are each set to None when the packet reports any of their components as
        unknown, so values from a previous parse never leak into the current one.

        :param payload:  Array of bytes to parse the ephemeris from
        :return: The number of words parsed (13), or 0 if the payload is shorter than 13 words (in which case this
                 object is left unchanged)
        """
        nwords = len(payload) // 4
        if nwords < 13:
            return 0

        word = read_word(payload, 0)
        tsi = read_uint_from_word(word, 26, 2)
        tsf = read_uint_from_word(word, 24, 2)
        self.manufacturer_oui = read_uint_from_word(word, 0, 24)

        # Sceptre only uses TSI=1 (UTC seconds) and TSF=2 (fractional picoseconds) so only parse time if it is in that
        # format
        if tsi == 1 and tsf == 2:
            integer_timestamp = read_word(payload, 1)
            upper = read_word(payload, 2)
            lower = read_word(payload, 3)
            fractional_timestamp = (upper << 32) | lower
            self.timestamp = Timestamp(integer_timestamp, float(fractional_timestamp) / 1e12)
        else:
            self.timestamp = Timestamp()

        # Words 4-6: position, 7-9: attitude, 10-12: velocity.  Each is signed fixed-point with 5 (2**5 = 32),
        # 22 (2**22 = 4194304) and 16 (2**16 = 65536) fractional bits respectively.  Assign unconditionally so an
        # "unknown" field in this packet clears any value left from an earlier parse.
        self.position = self._read_fixed_point_triplet(payload, 4, 32.0)
        self.attitude = self._read_fixed_point_triplet(payload, 7, 4194304.0)
        self.velocity = self._read_fixed_point_triplet(payload, 10, 65536.0)

        return 13


class VRTHeader:
    """
    Represents the contents of a VRT header
    """
    packet_type: int
    """Packet Type (section 6.1.1, table 6.1.1-1)"""
    has_class_id: bool
    """Whether or not the packet contains a class ID field"""
    has_trailer: bool
    """Whether or not the packet contains a trailer word"""
    tsi: int
    """Timestamp integer format (table 6.1.1-2)"""
    tsf: int
    """Timestamp fractional format (table 6.1.1-3)"""
    counter: int
    """Packet counter (modulo 16)"""
    packet_size: int
    """Packet size in 32-bit words"""
    stream_id: int
    """Stream identifier"""
    class_oui: int
    """Class ID OUI (left at 0 when the packet has no class ID)"""
    information_class_code: int
    """Information class code (left at 0 when the packet has no class ID)"""
    packet_class_code: int
    """Packet class code (left at 0 when the packet has no class ID)"""
    timestamp: Timestamp
    """Timestamp of packet"""

    def __init__(self):
        self.packet_type = 0
        self.has_class_id = False
        self.has_trailer = False
        self.tsi = 0
        self.tsf = 0
        self.counter = 0
        self.packet_size = 0
        self.stream_id = 0
        self.class_oui = 0
        self.information_class_code = 0
        self.packet_class_code = 0
        self.timestamp = Timestamp()

    def parse(self, packet: bytes) -> int:
        """
        Parses the VRT header from an array of bytes into this object

        :param packet:  Array of bytes to parse the header from
        :return: The number of header words in the packet, or zero if the buffer ends before all of the header fields
                 announced by the first header word
        """
        nwords = len(packet) // 4
        if nwords < 1:
            return 0

        # See section 6.1.1
        # Read the VRT header (Figure 6.1.1-1)
        word = read_word(packet, 0)
        word_index = 1
        self.packet_type = read_uint_from_word(word, 28, 4)
        self.has_class_id = read_bit_from_word(word, 27)
        self.has_trailer = read_bit_from_word(word, 26)
        self.tsi = read_uint_from_word(word, 22, 2)
        self.tsf = read_uint_from_word(word, 20, 2)
        self.counter = read_uint_from_word(word, 16, 4)
        self.packet_size = read_uint_from_word(word, 0, 16)

        # Read the stream ID if the packet_type indicates it contains one
        if self.packet_type in (1, 3, 4, 5):
            if word_index + 1 > nwords:
                return 0
            self.stream_id = read_word(packet, word_index)
            word_index += 1

        # Read the two-word class ID if it is present
        if self.has_class_id:
            if word_index + 2 > nwords:
                return 0
            self.class_oui = read_uint_from_word(read_word(packet, word_index), 0, 24)
            word = read_word(packet, word_index + 1)
            self.information_class_code = read_uint_from_word(word, 16, 16)
            self.packet_class_code = read_uint_from_word(word, 0, 16)
            word_index += 2

        # Read the one-word integer timestamp field if it exists
        integer_timestamp = 0
        if self.tsi != 0:
            if word_index + 1 > nwords:
                return 0
            integer_timestamp = read_word(packet, word_index)
            word_index += 1

        # Read the fractional timestamp field if it exists.  It spans two words, so both must be present; checking
        # for only one would let a truncated buffer raise struct.error instead of returning 0.
        fractional_timestamp = 0
        if self.tsf != 0:
            if word_index + 2 > nwords:
                return 0
            upper = read_word(packet, word_index)
            lower = read_word(packet, word_index + 1)
            fractional_timestamp = (upper << 32) | lower
            word_index += 2

        # Sceptre only writes timestamps using UTC seconds (TSI=1) and Real-Time picoseconds fractional seconds (TSF=2).
        # If the timestamp is formatted like this, create a timestamp object based on the integer and fractional
        # timestamp values read from the header; otherwise reset it so no stale value survives a re-parse.
        if self.tsi == 1 and self.tsf == 2:
            self.timestamp = Timestamp(integer_timestamp, float(fractional_timestamp) / 1e12)
        else:
            self.timestamp = Timestamp()

        return word_index


class VRTContext:
    """
    Represents the contents of a VITA-49 context packet
    """
    sample_format: np.dtype
    """Sample data format"""
    sample_rate: float
    """Sample rate (Hz)"""
    rf_reference: float
    """RF reference (Hz): This corresponds the DC frequency of the IF data"""
    bandwidth: float
    """Bandwidth of signal (Hz).  This may be less than the sampled bandwidth due to filter rolloff."""
    gain1: float
    """First stage gain (dB)"""
    gain2: float
    """Second stage gain (dB)"""
    ephemeris: ECEFEphemeris | None
    """Ephemeris (ECEF)"""

    def __init__(self):
        # A dtype instance, not the bare np.uint8 scalar type, so attributes such as .itemsize are real values even
        # when the context packet has no data format field (this matches parse_sample_format's fallback).
        self.sample_format = np.dtype(np.uint8)
        self.sample_rate = 0.0
        self.rf_reference = 0.0
        self.bandwidth = 0.0
        self.gain1 = 0.0
        self.gain2 = 0.0
        self.ephemeris = None

    def parse(self, payload: bytes) -> int:
        """
        Parses a VITA-49 context packet payload from an array of bytes

        Sceptre only writes a small subset of the context fields available in VITA-49, so this functions only parses
        those supported fields.  However, it will properly skip over and discard other fields if they exist.

        :param payload: Array of bytes to read the context packet from (this should just be the payload of the VRT
                        packet, without the VRT header)
        :return: Number of 32-bit words read from the input, or 0 if the payload is empty or a field that is read is
                 truncated.  Fields that are only skipped are not bounds-checked, so for a truncated packet the
                 returned count can exceed the payload length.
        """
        word_index = 0
        nwords = len(payload) // 4

        if word_index + 1 > nwords:
            return 0

        # Read the Context Indicator Field.  This word contains flags for whether each context field is present.
        cif = read_word(payload, word_index)
        word_index += 1

        # Read each indicator flag from the context indicator(Table 7.1.5.1-1)
        cif_ctx_assoc = read_bit_from_word(cif, 8)
        cif_gps_ascii = read_bit_from_word(cif, 9)
        cif_eph_ref = read_bit_from_word(cif, 10)
        cif_rel_eph = read_bit_from_word(cif, 11)
        cif_ecef_eph = read_bit_from_word(cif, 12)
        cif_ins_geo = read_bit_from_word(cif, 13)
        cif_gps_geo = read_bit_from_word(cif, 14)
        cif_format = read_bit_from_word(cif, 15)
        cif_state_event = read_bit_from_word(cif, 16)
        cif_dev_id = read_bit_from_word(cif, 17)
        cif_temperature = read_bit_from_word(cif, 18)
        cif_time_cal_time = read_bit_from_word(cif, 19)
        cif_time_adj = read_bit_from_word(cif, 20)
        cif_sample_rate = read_bit_from_word(cif, 21)
        cif_over_range = read_bit_from_word(cif, 22)
        cif_gain = read_bit_from_word(cif, 23)
        cif_ref_level = read_bit_from_word(cif, 24)
        cif_band_offset = read_bit_from_word(cif, 25)
        cif_rf_offset = read_bit_from_word(cif, 26)
        cif_rf = read_bit_from_word(cif, 27)
        cif_if = read_bit_from_word(cif, 28)
        cif_bw = read_bit_from_word(cif, 29)
        cif_ref_point_id = read_bit_from_word(cif, 30)

        # Read or skip over each context field if it is present.  Fields appear in descending CIF bit order.
        # (Bit 31, the context field change indicator, is a flag only and has no field.)

        if cif_ref_point_id:
            word_index += 1

        if cif_bw:
            if word_index + 2 > nwords:
                return 0
            self.bandwidth = read_frequency(payload, word_index * 4)
            word_index += 2

        if cif_if:
            word_index += 2

        if cif_rf:
            if word_index + 2 > nwords:
                return 0
            self.rf_reference = read_frequency(payload, word_index * 4)
            word_index += 2

        if cif_rf_offset:
            word_index += 2

        if cif_band_offset:
            word_index += 2

        if cif_ref_level:
            word_index += 1

        if cif_gain:
            if word_index + 1 > nwords:
                return 0
            # Stage 2 gain is the upper 16 bits, stage 1 the lower 16 bits; both are signed with 7 fractional bits.
            self.gain2 = read_i16(payload, word_index * 4) / 128.0
            self.gain1 = read_i16(payload, word_index * 4 + 2) / 128.0
            word_index += 1

        if cif_over_range:
            word_index += 1

        if cif_sample_rate:
            if word_index + 2 > nwords:
                return 0
            self.sample_rate = read_frequency(payload, word_index * 4)
            word_index += 2

        if cif_time_adj:
            word_index += 2

        if cif_time_cal_time:
            word_index += 1

        if cif_temperature:
            word_index += 1

        if cif_dev_id:
            word_index += 2

        if cif_state_event:
            word_index += 1

        if cif_format:
            if word_index + 2 > nwords:
                return 0
            word1 = read_word(payload, word_index)
            word2 = read_word(payload, word_index + 1)
            self.sample_format = parse_sample_format(word1, word2)
            word_index += 2

        if cif_gps_geo:
            word_index += 11

        if cif_ins_geo:
            word_index += 11

        if cif_ecef_eph:
            if word_index + 13 > nwords:
                return 0
            if self.ephemeris is None:
                self.ephemeris = ECEFEphemeris()
            self.ephemeris.parse(payload[word_index * 4 : (word_index + 13) * 4])
            word_index += 13

        if cif_rel_eph:
            word_index += 13

        if cif_eph_ref:
            word_index += 1

        if cif_gps_ascii:
            # GPS ASCII is variable length (section 7.1.5.24).  Per VITA-49 the field is a manufacturer OUI word, a
            # "number of words" word, and then that many words of ASCII data, so skip both fixed words plus the data.
            # The count word must itself be in bounds before it is read.
            if word_index + 2 > nwords:
                return 0
            word_count = read_word(payload, word_index + 1)
            if word_index + 2 + word_count > nwords:
                return 0
            word_index += 2 + word_count

        if cif_ctx_assoc:
            # Not supported.  Sceptre will not write this field and since it is the last field, so we don't need to
            # worry about figuring out its size.
            pass

        return word_index


class VRTPacket:
    """
    Result of decoding one VRT packet: its header, plus any context and data frame associated with it
    """
    # Header carried by every VRT packet
    header: VRTHeader
    # Context associated with this packet's stream.  It was not necessarily contained in this packet; it may have
    # been received earlier and associated with the stream.
    context: VRTContext | None
    # Decoded data frame, if any.  A spectral frame may have been assembled from several preceding packets.
    frame: IFFrame | SpectralFrame | None

    def __init__(self, header: VRTHeader, context: VRTContext | None, frame: IFFrame | SpectralFrame | None):
        self.header, self.context, self.frame = header, context, frame


class VITA49Decoder:
    """
    VITA-49 VRT packet decoder

    This object decodes VRT packets to obtain IF or spectral data and metadata.
    """
    raw_output: bool
    """
    If true, data will be output in the same format in which it was received, otherwise it will be converted to 
    floating point
    """
    contexts: dict
    """Mapping of stream ID to current context packet"""
    counters: dict
    """Mapping of stream ID to the most recent packet counter received (used to detect packet loss)"""
    drop_count: int
    """Estimated number of packets that have been dropped"""
    spec_frame: np.ndarray | None
    """Buffer used to build up complete spectral frame from multiple input packets"""
    spec_index: int
    """Current index where data is expected to be written next to spec_frame"""
    warn_interval: float = 1.0
    """How often to warn the user about packet drops (seconds)"""
    last_warn_drop_count: int
    """The packet drop count when the user was last warned about packet loss"""
    last_warn_drop_time: float
    """The time the user was last warned about packet loss"""

    def __init__(self, raw_output: bool = False):
        """
        Constructs a VITA49Decoder object

        :param raw_output: whether to return data in the raw format contained in the VRT packets.  If false, data is
                           automatically converted to floating point and scaled according to the gains found in the
                           context fields.
        """
        self.raw_output = raw_output
        self.contexts = {}
        self.counters = {}
        self.drop_count = 0
        self.spec_frame = None
        self.spec_index = 0
        self.last_warn_drop_count = 0
        self.last_warn_drop_time = 0.0

    def decode_packet(self, packet: bytes) -> VRTPacket | None:
        """
        Decodes a single VITA-49 packet.

        The packet can contain just a VRT packet, or may also contain the VRL layer.  If VRL is present, it is simply
        discarded and the VRT packet is decoded (VRL is not used for anything)

        :param packet: Array of bytes in packet
        :return: Decoded VRTPacket or None if packet could not be successfully decoded for some reason
        """

        # The minimum valid VRT packet is 16 bytes (4 words), so just return early if we don't have that
        if len(packet) < 16:
            return None

        # If the first word of the packet is ASCII "VRLP", strip the VRL layer off: its two header words and its
        # one trailer word. (See VITA-49.1)
        word = read_word(packet, 0)
        if word == 0x56524C50:
            packet = packet[8:-4]

        # Decode the VRT packet
        return self.decode_vrt_packet(packet)

    def decode_vrt_packet(self, packet: bytes) -> VRTPacket | None:
        """
        Decodes a single VITA-49 VRT packet.

        :param packet: Array of bytes in packet
        :return: Decoded VRTPacket or None if packet could not be successfully decoded for some reason
        """
        # Parse the main VRT header present in every packet
        hdr = VRTHeader()
        word_index = hdr.parse(packet)
        # If parsing the header returns 0, it indicates an error, so just return None
        if word_index == 0:
            return None

        # The packet size field (in words) covers the header, payload, and optional one-word trailer.  Use it, rather
        # than the end of the buffer, to locate the payload, so any bytes after the packet are ignored.
        trailer_words = 1 if hdr.has_trailer else 0
        payload_words = hdr.packet_size - word_index - trailer_words
        # Reject packets whose size field is smaller than their own header/trailer or larger than the received buffer.
        if payload_words < 0 or len(packet) < hdr.packet_size * 4:
            return None
        payload = packet[word_index * 4:(word_index + payload_words) * 4]

        # Check packet counter for drops, then print a warning if packet loss was detected.
        self._count_drops(hdr)
        self._warn_on_packet_loss()

        # Parse the packet payload based on the packet type.  Sceptre only sends the following packet types:
        #  1. IF Data packet with stream ID: Time-domain IQ or real-valued samples
        #  3. Extension data packet with stream ID: Spectral frame fragment
        #  4: IF context packet: Contains additional metadata about the data stream
        if hdr.packet_type == 1:
            # IF data packet.
            # Only parse the data if we have a context for this stream ID since we need to know data format
            if hdr.stream_id in self.contexts:
                ctx = self.contexts[hdr.stream_id]
                return VRTPacket(hdr, ctx, self._decode_if_data(hdr, ctx, payload))

        elif hdr.packet_type == 3:
            # Extension data packet.  Sceptre uses its own custom extension packets for spectral data.  VITA-49.2
            # introduced the ability to stream spectral data, but this has not been added to Sceptre yet.  In most
            # cases, an entire SpectralFrame does not fit within a single VRT packet, so each VRT packet just contains
            # a "fragment" of a spectral frame.  These fragments are then collected to build up the complete frame.
            #
            # For spectral extension packets, sceptre embeds a class OUI of "0x3db3db" in the VRT header, so check for
            # that class OUI before attempting to decode this as a spectral frame fragment.

            # Only decode the packet if we have received a context packet for this stream since we need data from the
            # context to properly decode it.
            if hdr.stream_id in self.contexts and hdr.has_class_id and hdr.class_oui == 0x3db3db:
                ctx = self.contexts[hdr.stream_id]
                # Attempt to decode a spectral frame from the payload.  This will only return a non-None value if this
                # is the last fragment and a frame is now complete.  Otherwise, just return the header and context to
                # indicate the packet was decoded, but did not produce any data yet.
                spec_frame = self._decode_spectrum_fragment(hdr, ctx, payload)
                return VRTPacket(hdr, ctx, spec_frame)

        elif hdr.packet_type == 4:
            # IF context packet
            # Create and parse the context packet payload
            ctx = VRTContext()
            ctx.parse(payload)
            # Save off the context packet associated with the stream ID of this VRT packet.  VITA-49 prescribes several
            # ways to link context packet streams to IF data streams, but Sceptre always uses the simple method of using
            # the same stream ID for the context packets and the data streams they are associated with.
            self.contexts[hdr.stream_id] = ctx
            return VRTPacket(hdr, ctx, None)

        return VRTPacket(hdr, None, None)

    def _count_drops(self, hdr: VRTHeader):
        """
        Updates drop_count from the packet counter of a data packet (types 0-3); other packet types are ignored

        :param hdr: Parsed VRT header of the packet just received
        """
        if hdr.packet_type > 3:
            return
        # Counters are continuous for each stream ID.
        previous = self.counters.get(hdr.stream_id)
        if previous is not None:
            expected_count = (previous + 1) % 16
            # The counter is only 4 bits, so only the modulo-16 gap can be measured: dropping 17 packets looks just
            # like dropping 1.
            self.drop_count += (hdr.counter - expected_count) % 16
        # Update the most recent counter for this stream ID
        self.counters[hdr.stream_id] = hdr.counter

    def _decode_if_data(self, hdr: VRTHeader, ctx: VRTContext, payload: bytes) -> IFFrame:
        """
        Builds an IFFrame from an IF data packet payload using the stream's context for format, rate, and gain

        :param hdr: VRT header of the packet
        :param ctx: Most recent context packet associated with this stream
        :param payload: Payload of the VRT packet (not including header or trailer)
        :return: IFFrame holding the decoded samples
        """
        output = IFFrame()
        # A context without a sample rate field leaves sample_rate at 0.0; report 0.0 rather than divide by zero.
        output.sample_delta = 1.0 / ctx.sample_rate if ctx.sample_rate else 0.0
        output.timestamp = hdr.timestamp
        output.center_frequency = ctx.rf_reference
        output.bandwidth = ctx.bandwidth

        # Create a NumPy array from the packet payload based on the data format specified in the context.
        # np.frombuffer rejects buffers that are not a whole number of items, so ignore any partial trailing item.
        sample_format = np.dtype(ctx.sample_format)
        usable_bytes = len(payload) - len(payload) % sample_format.itemsize
        # VRT data is big-endian; byteswap() also returns a copy, so the result is writable (frombuffer's is not).
        output.data = np.frombuffer(payload[:usable_bytes], dtype=sample_format).byteswap()

        total_gain = ctx.gain1 + ctx.gain2
        if self.raw_output:
            # Return the gain in the output IFFrame when returning raw samples so the user can apply it if needed.
            output.gain_dB = total_gain
        else:
            # Convert to floating point and remove the gain here, so report zero remaining gain.
            output.gain_dB = 0.0
            output.data = convert_to_float(output.data)
            # Gain is in dB, convert to absolute
            scale = 10 ** (-total_gain / 20.0)
            if scale != 1.0:
                output.data *= scale
        return output

    def _decode_spectrum_fragment(self, hdr: VRTHeader, ctx: VRTContext, payload: bytes) -> SpectralFrame | None:
        """
        Decodes a VRT packet payload that contains a spectral frame fragment

        NOTE(review): spec_frame/spec_index are shared by all stream IDs, so this presumably assumes a single spectral
        stream per decoder; interleaved spectral streams would keep discarding each other's partial frames.

        :param hdr: VRT header of the packet
        :param ctx: Most recent context packet associated with this VRT packet
        :param payload: Payload of the VRT packet (not including header or trailer)
        :return: if this fragment completes a spectral frame, returns a SpectralFrame, otherwise returns None
        """
        # Spectrum fragments contain 28 bytes of header.  Make sure we have at least that much payload before parsing.
        if len(payload) < 28:
            return None

        # Sceptre embeds several metadata fields at the beginning of each spectral fragment.  Read those parameters.
        # Start frequency of the resulting frame (this should be the same for all fragments from a single frame)
        freq_start = read_f64(payload, 0)
        # frequency delta between each spectral bin
        freq_delta = read_f64(payload, 8)
        # Overall size of the spectral frame (this should be the same for all fragments from the frame)
        frame_size = read_u32(payload, 16)
        # Start bin index of this fragment (this indicates where in the spectral frame this fragment fits)
        start_bin = read_u32(payload, 20)
        # Number of spectral bins in this fragment
        nbins = read_u32(payload, 24)

        # Make sure there are enough bytes in the payload to cover the reported number of bins.  A truncated fragment
        # means the current frame can never complete, so restart.
        sample_format = np.dtype(ctx.sample_format)
        fragment_bytes = nbins * sample_format.itemsize
        if fragment_bytes > len(payload) - 28:
            self.spec_index = 0
            return None

        # Create a NumPy array from the payload based on the format (raw, or float32/complex64 when converting)
        data = np.frombuffer(payload[28:28 + fragment_bytes], dtype=sample_format).byteswap()
        if not self.raw_output:
            data = convert_to_float(data)

        # On the first fragment, or if the frame size or output dtype changed, allocate a fresh buffer and restart the
        # frame.  The buffer takes the decoded data's dtype so complex spectra keep their imaginary part.
        if self.spec_frame is None or len(self.spec_frame) != frame_size or self.spec_frame.dtype != data.dtype:
            self.spec_frame = np.zeros(frame_size, dtype=data.dtype)
            self.spec_index = 0

        # If this fragment is not where we expect to receive next (we joined mid-frame or dropped packets), or it would
        # run past the end of the frame, discard it and wait until we can start the next frame.  We only want to output
        # fully received spectral frames.
        if start_bin != self.spec_index or start_bin + nbins > frame_size:
            self.spec_index = 0
            return None

        # Insert the fragment into the current spectral frame we are working on
        self.spec_frame[start_bin:start_bin + nbins] = data
        # Update the index we expect the next fragment to be written to
        self.spec_index = start_bin + nbins

        # If this fragment completes a frame, return it
        if self.spec_index == frame_size:
            frame = SpectralFrame()
            # Hand the buffer over to the caller and start the next frame in a new one, so later fragments do not
            # overwrite data the caller is still holding.
            frame.data = self.spec_frame
            self.spec_frame = None
            frame.timestamp = hdr.timestamp
            frame.frequency_start = freq_start
            frame.frequency_delta = freq_delta
            # A context without a sample rate leaves it at 0.0; report 0.0 rather than divide by zero.
            frame.frame_delta = 1.0 / ctx.sample_rate if ctx.sample_rate else 0.0
            # Reset the spectral frame index so we will start a new frame on the next call.
            self.spec_index = 0
            return frame

        return None

    def _warn_on_packet_loss(self):
        """
        Warns the user periodically if packet loss is detected
        """
        # Only print a warning at the maximum rate defined by self.warn_interval.  If we print a warning for every drop
        # it could use significant CPU and further exacerbate the problem.
        t = time.time()
        if self.drop_count != self.last_warn_drop_count and t - self.last_warn_drop_time > self.warn_interval:
            print(f'packet loss detected: drops={self.drop_count}')
            self.last_warn_drop_time = t
            self.last_warn_drop_count = self.drop_count
