import numpy as np
import copy
from .common import IFFrame
from .common import SpectralFrame
from .timestamp import Timestamp
import math


class PSDEngine:
    """
    Calculates a PSD (spectrum) stream from an IF (time-domain) data stream

    The general process for producing the PSD is as follows:
        1. Buffer up `nfft` input samples
        2. Apply time-domain window (currently Hann via ``np.hanning``, normalized to unit sum)
        3. Perform FFT (full FFT for complex input, ``rfft`` for real input)
        4. Shift DC bin to center (fftshift; complex input only)
        5. Calculate mag^2 of frame
        6. Average mag^2 spectral frames
        7. Convert mag^2 average to dB
    """
    nfft: int
    """FFT size"""
    navg: int
    """Number of spectral frames to average on output"""
    input_frame: np.ndarray | None
    """Input frame buffer"""
    input_timestamp: Timestamp
    """Timestamp of the start of the input frame buffer"""
    input_delta: float
    """Sample delta of the data in the input frame buffer"""
    input_fc: float
    """Center frequency of the input frame buffer"""
    input_cpx: bool
    """Whether or not input data is complex"""
    input_size: int
    """Current used size of the input frame buffer"""
    window: np.ndarray
    """Time-domain window to apply to the input frame"""
    output_accum: np.ndarray | None
    """Accumulator for calculating output frame average"""
    output_count: int
    """Number of output frames accumulated in output_accum"""
    output_start_timestamp: Timestamp
    """Timestamp of the first time sample used to produce the current output frame being worked on"""
    output_delta: float
    """Time delta between each output frame"""
    output_fc: float
    """Center frequency of the output frame"""
    output_frames: list[SpectralFrame]
    """List of frames to output (for small FFT sizes, we may have multiple frames produced by a single frame of input)"""

    def __init__(self, nfft: int = 4096, navg: int = 10):
        """
        Constructs a PSDEngine object

        :param nfft: FFT size.  This will control how many frequency bins are output as well as the tradeoff between
                     time and frequency resolution.
        :param navg: Number of output frames to average.
        """
        self.nfft = nfft
        self.navg = navg
        self.input_frame = None
        self.input_timestamp = Timestamp()
        self.input_delta = 1.0
        # Declared on the class but previously never initialized; give them a defined value up front.
        self.input_fc = 0.0
        self.input_size = 0
        self.input_cpx = True
        # Hann window normalized to unit sum so a tone's peak bin reflects its amplitude independent of nfft.
        self.window = np.hanning(self.nfft)
        self.window /= np.sum(self.window)
        self.output_accum = None
        self.output_count = 0
        self.output_start_timestamp = Timestamp()
        self.output_delta = 1.0
        self.output_fc = 0.0
        self.output_frames = []

    def process(self, if_frame: IFFrame) -> list[SpectralFrame]:
        """
        Process an IFFrame (time-domain) to produce spectral PSD frames

        The processed IFFrames do not have to have any particular relation to the FFT or averaging size, therefore one
        call to process may produce multiple output frames, or no frames at all.  For instance, if the FFT size is
        16K, navg is 10, and the input IFFrames are only 1K samples, it will take roughly 160 calls to process to
        produce a single output frame.  If FFT size is 64, navg is 1, and IFFrames are 1K, one call to process() might
        produce 16 spectral frames.

        :param if_frame: Time-domain input
        :return: List of frequency-domain output frames (ownership passes to the caller)
        """
        data = if_frame.data
        frame_cpx = np.iscomplexobj(data)

        # Allocate the input buffer on first use.  If a later frame has a wider dtype than the buffer (complex after
        # real, float after integer, ...), promote the buffer so the copy below does not silently truncate samples.
        if self.input_frame is None:
            self.input_frame = np.zeros(self.nfft, dtype=data.dtype)
        else:
            buffer_dtype = np.result_type(self.input_frame.dtype, data.dtype)
            if buffer_dtype != self.input_frame.dtype:
                self.input_frame = self.input_frame.astype(buffer_dtype)

        # Insert the input frame into the input buffer.  Each time the buffer fills, process it and start building the
        # next one, until the input is all used.  `index` tracks our position in the input.
        index = 0
        nin = len(data)
        while index < nin:
            if self.input_size == 0:
                # First data in this FFT window: latch the timestamp and other metadata
                self.input_timestamp = if_frame.timestamp + Timestamp(if_frame.sample_delta * index)
                self.input_delta = if_frame.sample_delta
                self.input_fc = if_frame.center_frequency
                self.input_cpx = frame_cpx
            elif frame_cpx:
                # Complex data joining a partially filled window makes the whole window complex; otherwise its
                # imaginary part would be discarded by the real-input FFT path.
                self.input_cpx = True

            # Copy as much input as fits in the remaining space of the buffer
            ncopy = min(nin - index, self.nfft - self.input_size)
            self.input_frame[self.input_size:self.input_size + ncopy] = data[index:index + ncopy]
            index += ncopy
            self.input_size += ncopy

            if self.input_size == self.nfft:
                self._process_input_frame()
                self.input_size = 0

        # Hand the pending frames to the caller and start a fresh list.  The engine keeps no other reference to these
        # frames, so no (deep) copy is needed.
        output_frames = self.output_frames
        self.output_frames = []
        return output_frames

    def _process_input_frame(self):
        """Window, FFT and mag^2 the full input buffer, accumulate it, and emit an average every `navg` frames."""
        if self.input_cpx:
            spec = np.abs(np.fft.fftshift(np.fft.fft(self.window * self.input_frame)))
        else:
            # `.real` is a no-op for a real buffer; it matters only when the buffer was promoted to a complex dtype
            # by an earlier frame but this window holds only real samples (rfft does not accept complex input).
            spec = np.abs(np.fft.rfft(self.window * self.input_frame.real))
        spec_mag2 = spec * spec

        # Start the average on first use, or restart it if the real/complex mode changed and the bin count no longer
        # matches the accumulator (the partial average cannot be combined with frames of a different size).
        if self.output_accum is None or self.output_accum.shape != spec_mag2.shape:
            self._reset_output_accum()

        # If this is the first FFT frame to average, latch the timestamp and other metadata of the window
        if self.output_count == 0:
            self.output_start_timestamp = self.input_timestamp
            self.output_delta = self.input_delta
            self.output_fc = self.input_fc

        self.output_accum += spec_mag2
        self.output_count += 1

        if self.output_count >= self.navg:
            self._emit_output_frame()
            self._reset_output_accum()

    def _emit_output_frame(self):
        """Convert the accumulated mag^2 average to dB and append it to `output_frames` as a SpectralFrame."""
        spec_frame = SpectralFrame()
        # Timestamp is the center of the time window of samples that contributed to the frame
        spec_frame.timestamp = self.output_start_timestamp + Timestamp(
            0.5 * self.output_count * self.nfft * self.output_delta)
        if self.input_cpx:
            # After fftshift, DC sits at bin nfft // 2, so the first bin is that many bins below the center frequency
            spec_frame.frequency_start = self.output_fc - (self.nfft // 2) / (self.output_delta * self.nfft)
        else:
            # rfft bins run from DC upward
            spec_frame.frequency_start = self.output_fc
        spec_frame.frequency_delta = 1.0 / (self.output_delta * self.nfft)
        # Averaged mag^2 in dB; bins with exactly zero power come out as -inf
        spec_frame.data = 10 * np.log10(np.abs(self.output_accum / self.output_count))
        self.output_frames.append(spec_frame)

    def _reset_output_accum(self):
        """Zero the averaging accumulator, sized for the current real/complex mode, and reset the frame count."""
        # Complex input yields nfft bins; real input (rfft) yields nfft // 2 + 1 bins
        nfft_out = self.nfft if self.input_cpx else self.nfft // 2 + 1
        self.output_accum = np.zeros(nfft_out, dtype=np.float32)
        self.output_count = 0
