import array
import contextlib
import io
import wave
from collections.abc import Iterable

import numpy
import torch


def frames2wav(resulting_array: Iterable[int], sample_rate: int) -> bytes:
    """Encode mono 16-bit PCM samples as a complete in-memory WAV file.

    Args:
        resulting_array: The samples to write. Either a bytes-like object
            whose raw bytes are native-endian 16-bit samples (``bytes``, a
            ``numpy`` ``int16`` array, ...), or any other iterable of ints
            in the int16 range (e.g. a ``list`` or generator), which is
            packed into 16-bit samples here.
        sample_rate: Frame rate stored in the WAV header, in Hz.

    Returns:
        The WAV file (header followed by the sample data) as ``bytes``.

    Raises:
        OverflowError: If a non-bytes-like iterable contains an int outside
            the int16 range.
        TypeError: If the input is neither bytes-like nor an iterable of ints.
    """
    try:
        frames = memoryview(resulting_array)
    except TypeError:
        # Not bytes-like (e.g. a list of ints): ``writeframes`` only accepts
        # buffer objects, so pack the values as native signed 16-bit ints.
        frames = memoryview(array.array('h', resulting_array))
    if not frames.c_contiguous:
        # ``wave`` reinterprets the buffer as raw bytes, which only works for
        # C-contiguous memory; strided views (e.g. ``a[::2]``) need a copy.
        frames = frames.tobytes()

    res_io_stream = io.BytesIO()
    # Closing the Wave_write finalises the header; it does not close a file
    # object it did not open itself, so the BytesIO stays readable.
    with wave.open(res_io_stream, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(frames)

    return res_io_stream.getvalue()


def ndarray2wav(resulting_array: numpy.ndarray, sample_rate: int) -> bytes:
    """Encode float audio samples as a mono 16-bit PCM WAV file.

    Each sample is multiplied by 32767 and truncated to int16, so the
    nominal input range is ``[-1.0, 1.0]``. Samples outside that range are
    saturated to the int16 limits instead of overflowing.

    Args:
        resulting_array: The samples; anything ``numpy.asarray`` accepts
            (an ndarray, a list of floats, ...).
        sample_rate: Frame rate stored in the WAV header, in Hz.

    Returns:
        The complete WAV file as ``bytes``.
    """
    samples = numpy.asarray(resulting_array)
    # Scale in at least float32: float16 cannot represent 32767 (1.0 * 32767
    # rounds to 32768), which would overflow int16 even after clipping.
    samples = samples.astype(numpy.promote_types(samples.dtype, numpy.float32), copy=False)
    # Saturate out-of-range samples: casting an out-of-range float to int16 is
    # unspecified and typically wraps around (a loud click in the audio).
    scaled = numpy.clip(samples * 32767, -32768, 32767)
    return frames2wav(scaled.astype(numpy.int16), sample_rate)


def tensor2wav(resulting_array: torch.Tensor, sample_rate: int) -> bytes:
    """Encode a float audio tensor as a mono 16-bit PCM WAV file.

    Each sample is multiplied by 32767 and truncated to int16, so the
    nominal input range is ``[-1.0, 1.0]``. Samples outside that range are
    saturated to the int16 limits instead of overflowing.

    Args:
        resulting_array: The samples. The tensor may live on any device and
            may require grad; it is not modified.
        sample_rate: Frame rate stored in the WAV header, in Hz.

    Returns:
        The complete WAV file as ``bytes``.
    """
    # ``.numpy()`` only works on CPU tensors outside the autograd graph.
    samples = resulting_array.detach().cpu()
    # Scale in at least float32: float16/bfloat16 cannot represent 32767, so
    # 1.0 * 32767 would round past the int16 maximum.
    samples = samples.to(torch.promote_types(samples.dtype, torch.float32))
    # Saturate out-of-range samples rather than letting the int16 cast wrap.
    scaled = (samples * 32767).clamp(-32768, 32767)
    # ``wave`` needs C-contiguous memory to reinterpret the buffer as bytes.
    return frames2wav(scaled.to(torch.int16).contiguous().numpy(), sample_rate)