#!/usr/bin/env python
# coding: utf-8

# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.

import os
import argparse
import logging
import sys
import requests
import time
import swagger_client
import re

# Send every log record (DEBUG and above) to stdout, prefixed with a local timestamp.
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stdout,
    datefmt="%m/%d/%Y %I:%M:%S %p %Z",
    format="%(asctime)s %(message)s",
)

# Speech-to-Text REST API version passed as `api_version` to the transcription API calls.
API_VERSION = "2024-11-15"

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Run app.py with custom parameters.")
parser.add_argument(
    "--service_key",
    type=str,
    required=True,
    help="The subscription key of the speech service."
)
parser.add_argument(
    "--service_region",
    type=str,
    required=True,
    help="The region for the speech service."
)
parser.add_argument(
    "--recordings_blob_uris",
    type=str,
    required=False,
    help="Comma-separated SAS URIs pointing to audio files stored in Azure Blob Storage."
)
parser.add_argument(
    "--recordings_container_uri",
    type=str,
    required=False,
    help="SAS URI pointing to a container in Azure Blob Storage."
)
parser.add_argument(
    "--locale",
    type=str,
    required=True,
    help="The locale of the input audio file."
)
args = parser.parse_args()

# Neither source option is individually required, but a transcription request needs at least
# one of them. Report the problem as a usage error (exit status 2) right away instead of
# failing later while building the request.
if not args.recordings_blob_uris and not args.recordings_container_uri:
    parser.error("one of --recordings_blob_uris or --recordings_container_uri is required")

# Speech resource credentials and region, copied from the command line.
SUBSCRIPTION_KEY, SERVICE_REGION = args.service_key, args.service_region

# Locale of the audio being transcribed.
LOCALE = args.locale

# Recording sources, both taken verbatim from the command line (either may be None):
#   RECORDINGS_BLOB_URIS     - comma-separated SAS URIs of individual audio files.
#   RECORDINGS_CONTAINER_URI - SAS URI of a container whose audio files are all transcribed
#                              with a single request; it needs at least 'read' and 'list'
#                              (rl) permissions.
RECORDINGS_BLOB_URIS = args.recordings_blob_uris
RECORDINGS_CONTAINER_URI = args.recordings_container_uri

# Display name and description attached to each transcription request built by this script.
DESCRIPTION = "Simple transcription description"
NAME = "Simple transcription"

# GUID of a custom model. Only used by transcribe_with_custom_model, which refuses to run
# while this is None.
MODEL_REFERENCE = None


def transcribe_from_single_blob(uris, properties):
    """
    Build a transcription request for the audio files at `uris`, using the settings specified
    in `properties` and the base model for LOCALE.

    :param uris: comma-separated string of SAS URIs, one per audio file. Whitespace around each
        URI is stripped, and empty entries (e.g. from a trailing or doubled comma) are ignored.
    :param properties: transcription settings (a `swagger_client.TranscriptionProperties`).
    :return: the (not yet submitted) `swagger_client.Transcription` definition.
    :raises ValueError: if `uris` is None or contains no non-blank URI.
    """
    content_urls = [uri.strip() for uri in (uris or "").split(",") if uri.strip()]
    if not content_urls:
        # Fail with a clear message instead of an AttributeError on None, or submitting a
        # request whose only content URL is an empty string.
        raise ValueError("No recording URIs were provided; pass --recordings_blob_uris.")

    transcription_definition = swagger_client.Transcription(
        display_name=NAME,
        description=DESCRIPTION,
        locale=LOCALE,
        content_urls=content_urls,
        properties=properties
    )

    return transcription_definition


def transcribe_with_custom_model(client, uris, properties):
    """
    Build a transcription request for the audio files at `uris`, using the settings specified
    in `properties` and the custom model referenced by MODEL_REFERENCE (not the base model).

    :param client: the `swagger_client.ApiClient`; its configured host is used to build the
        model URL.
    :param uris: comma-separated string of SAS URIs, one per audio file. Whitespace around each
        URI is stripped, and empty entries are ignored.
    :param properties: transcription settings (a `swagger_client.TranscriptionProperties`).
    :return: the (not yet submitted) `swagger_client.Transcription` definition.
    :raises ValueError: if `uris` is None or contains no non-blank URI.

    Logs an error and exits the process with status 1 if MODEL_REFERENCE is not set.
    """
    # MODEL_REFERENCE (the custom model GUID) must be set at module level before using this.
    if MODEL_REFERENCE is None:
        logging.error("Custom model ids must be set when using custom models")
        # Non-zero exit status so shells/CI see this as a failure, not a successful run.
        sys.exit(1)

    content_urls = [uri.strip() for uri in (uris or "").split(",") if uri.strip()]
    if not content_urls:
        raise ValueError("No recording URIs were provided; pass --recordings_blob_uris.")

    model = {'self': f'{client.configuration.host}/models/{MODEL_REFERENCE}'}

    transcription_definition = swagger_client.Transcription(
        display_name=NAME,
        description=DESCRIPTION,
        locale=LOCALE,
        content_urls=content_urls,
        model=model,
        properties=properties
    )

    return transcription_definition


def transcribe_from_container(uri, properties):
    """
    Build a transcription request that covers every file in the container at `uri`, transcribed
    with the base model for LOCALE using the settings in `properties`.

    :param uri: SAS URI of the blob container holding the audio files.
    :param properties: transcription settings (a `swagger_client.TranscriptionProperties`).
    :return: the (not yet submitted) `swagger_client.Transcription` definition.
    """
    request_fields = {
        "display_name": NAME,
        "description": DESCRIPTION,
        "locale": LOCALE,
        "content_container_url": uri,
        "properties": properties,
    }
    return swagger_client.Transcription(**request_fields)


def _paginate(api, paginated_object):
    """
    Yield every item of a paginated listing, starting with the page `paginated_object` and
    following its `next_link` chain.

    The autogenerated client has no pagination support, so each follow-up page is fetched with
    a raw GET through `api.api_client.call_api` and deserialized into the same model type as
    the first page.

    :param api: generated API wrapper exposing `api_client`.
    :param paginated_object: first page; must have `values` and `next_link` attributes.
    :raises Exception: if a follow-up page is returned with a status other than 200.
    """
    page = paginated_object
    while True:
        yield from page.values
        if not page.next_link:
            return

        api_client = api.api_client
        # next_link is absolute; strip the configured host prefix before handing it to
        # call_api (presumably call_api prepends the host itself).
        relative_link = page.next_link[len(api_client.configuration.host):]
        page, status_code, _headers = api_client.call_api(
            relative_link,
            "GET",
            response_type=type(paginated_object).__name__,
            auth_settings=["api_key"],
        )
        if status_code != 200:
            raise Exception(f"could not receive paginated data: status {status_code}")


def delete_all_transcriptions(api):
    """
    Delete all transcriptions associated with your speech resource.

    Transcriptions that are still running or not yet started will not be deleted. Any
    `swagger_client.rest.ApiException` raised while deleting one transcription is logged and
    the loop moves on to the next one.

    :param api: the `swagger_client.CustomSpeechTranscriptionsApi` instance to use.
    """
    logging.info("Deleting all existing completed transcriptions.")

    # Collect every page up front (presumably so deletions cannot shift later pages while
    # paginating). Use the same `transcriptions_*` operations and API version as the other
    # calls made through this client.
    transcriptions = list(_paginate(api, api.transcriptions_list(api_version=API_VERSION)))

    for transcription in transcriptions:
        # The self link may carry a query string (e.g. "?api-version=..."); drop it first so
        # only the bare id remains after the last '/'.
        transcription_id = transcription._self.split('?')[0].split('/')[-1]
        logging.debug(f"Deleting transcription with id {transcription_id}")
        try:
            api.transcriptions_delete(transcription_id, api_version=API_VERSION)
        except swagger_client.rest.ApiException as exc:
            logging.error(f"Could not delete transcription {transcription_id}: {exc}")


def _download_results(api, transcription_id):
    """
    Download the result files of kind "Transcription" for `transcription_id` into ./results.

    Each file is saved under its service-reported name, with characters that are invalid in
    file names replaced by '_'.

    :raises requests.HTTPError: if downloading a result file returns an HTTP error status.
    """
    os.makedirs('results', exist_ok=True)

    pag_files = api.transcriptions_list_files(transcription_id, api_version=API_VERSION)
    for file_data in _paginate(api, pag_files):
        # Only files of kind "Transcription" are saved.
        if file_data.kind != "Transcription":
            continue

        # Bounded wait so a stalled download cannot hang the script forever.
        response = requests.get(file_data.links.content_url, timeout=60)
        # Fail loudly instead of saving an HTTP error response body as if it were a result.
        response.raise_for_status()

        safe_audiofilename = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', file_data.name)
        result_file_path = os.path.join('results', safe_audiofilename)
        # Save the downloaded bytes unchanged: no decode/re-encode round trip and no
        # platform newline translation.
        with open(result_file_path, 'wb') as f:
            f.write(response.content)
        logging.info(f"Results saved to {result_file_path}")


def transcribe():
    """
    Submit a batch transcription for the configured recordings, poll its status every
    5 seconds until it succeeds or fails, then download the results. Results are not
    downloaded when a destination container is configured in the properties.

    Blob URIs (RECORDINGS_BLOB_URIS) take precedence. The container URI
    (RECORDINGS_CONTAINER_URI) is used only when no blob URIs are configured.

    :raises ValueError: if neither recording source is configured.
    """
    logging.info("Starting transcription client...")

    # configure API key authorization: subscription_key
    configuration = swagger_client.Configuration()
    configuration.api_key["Ocp-Apim-Subscription-Key"] = SUBSCRIPTION_KEY
    configuration.host = f"https://{SERVICE_REGION}.api.cognitive.microsoft.com/speechtotext"

    # create the client object and authenticate
    client = swagger_client.ApiClient(configuration)

    # create an instance of the transcription api class
    api = swagger_client.CustomSpeechTranscriptionsApi(api_client=client)

    # Specify transcription properties by passing a dict to the properties parameter. See
    # https://learn.microsoft.com/azure/cognitive-services/speech-service/batch-transcription-create?pivots=rest-api#request-configuration-options # noqa: E501
    # for supported parameters.
    properties = swagger_client.TranscriptionProperties(time_to_live_hours=6)
    # properties.word_level_timestamps_enabled = True
    # properties.display_form_word_level_timestamps_enabled = True
    # properties.punctuation_mode = "DictatedAndAutomatic"
    # properties.profanity_filter_mode = "Masked"
    # properties.destination_container_url = "<SAS Uri with at least write (w) permissions>"
    # # The container where results should be written to

    # uncomment the following block to enable and configure speaker separation
    # properties.diarization = swagger_client.DiarizationProperties(max_speakers=5, enabled=True)

    # uncomment the following block to enable and configure language identification prior to transcription.
    # Available modes are "single" and "continuous".
    # properties.language_identification = swagger_client.LanguageIdentificationProperties(
    #     mode="single", candidate_locales=["en-US", "ja-JP"])

    if RECORDINGS_BLOB_URIS:
        # Use base models for transcription. To use a custom model instead, replace this with:
        # transcription_definition = transcribe_with_custom_model(client, RECORDINGS_BLOB_URIS, properties)
        transcription_definition = transcribe_from_single_blob(RECORDINGS_BLOB_URIS, properties)
    elif RECORDINGS_CONTAINER_URI:
        # Transcribe all files from the container with a single request.
        transcription_definition = transcribe_from_container(RECORDINGS_CONTAINER_URI, properties)
    else:
        raise ValueError(
            "No recordings to transcribe: pass --recordings_blob_uris or --recordings_container_uri.")

    _, _, headers = api.transcriptions_submit_with_http_info(
        transcription=transcription_definition, api_version=API_VERSION)

    # Get the transcription id from the location URI: drop any query string first, then keep
    # the last path segment.
    transcription_id = headers["location"].split("?")[0].split("/")[-1]

    # Log information about the created transcription. If you should ask for support, please
    # include this information.
    logging.info(f"Created new transcription with id '{transcription_id}' in region {SERVICE_REGION}")

    logging.info("Checking status.")

    # Poll every 5 seconds until the transcription reaches a terminal state.
    while True:
        time.sleep(5)
        transcription = api.transcriptions_get(transcription_id, api_version=API_VERSION)
        logging.info(f"Transcription status: {transcription.status}")
        if transcription.status in ("Failed", "Succeeded"):
            break

    if transcription.status == "Succeeded":
        if properties.destination_container_url is not None:
            logging.info("Transcription succeeded. Results are located in your Azure Blob Storage.")
        else:
            _download_results(api, transcription_id)
    else:
        # Guard against missing properties/error on the returned model so the failure is
        # still reported rather than masked by an AttributeError.
        error = getattr(transcription.properties, "error", None)
        message = error.message if error is not None else "no error details returned"
        logging.error(f"Transcription failed: {message}")


if __name__ == "__main__":
    # Run the end-to-end batch transcription only when executed as a script, not on import.
    transcribe()
