Android 中以音频剪辑作为输入的音频分类(使用 YAMNet TensorFlow Lite 模型)

问题描述 投票:0回答:1

我用过这个音频分类应用程序,我真的很喜欢它: https://github.com/tensorflow/examples/tree/master/lite/examples/audio_classification/android/

此示例使用手机麦克风开始录音(现场录音)并对输入进行分类。

这是使用该模型的代码:

// Instantiate the YAMNet model wrapper for this Context.
// NOTE(review): presumably the class generated by Android Studio's ML Model Binding from the
// .tflite file; such wrappers should be released with model.close() once inference is done —
// confirm and add the call.
val model = YamnetClassification.newInstance(this)

// Input tensor: a fixed-size 1-D FLOAT32 buffer of 15600 values, filled from byteBuffer.
// NOTE(review): YAMNet expects a mono 16 kHz waveform with samples in [-1.0, 1.0], so 15600
// samples is about 0.975 s of audio; byteBuffer is assumed to hold exactly 15600 float32 values
// (62400 bytes) in the byte order loadBuffer expects — confirm against the model metadata.
val audioClip = TensorBuffer.createFixedSize(intArrayOf(15600), DataType.FLOAT32)
audioClip.loadBuffer(byteBuffer)

// Run inference and read the "scores" output tensor.
// NOTE(review): presumably one score per YAMNet class (521 in the published model) — verify.
val outputs = model.process(audioClip)
val scores = outputs.scoresAsTensorBuffer

我有一个音频剪辑(例如原始资源 R.raw.s1,即 res/raw/s1.wav),我想把它作为输入,借助 YAMNet 模型对其进行分类。

如何修改上面的示例,使其以音频剪辑而不是现场录音作为输入?我认为音频剪辑需要先转换为 ByteBuffer,那么该如何对音频剪辑完成这种转换?

android audio deep-learning tensorflow-lite
1个回答
0
投票

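思路与 YAMNet 文档中的 Python 示例相同(https://www.kaggle.com/models/google/yamnet/frameworks/tensorFlow2/variations/yamnet/versions/1?tfhub-redirect=true):先把音频解码为 16 位 PCM,再把采样值归一化为 [-1.0, 1.0) 区间的浮点数,并按固定长度分段送入模型。下面的代码分三步完成:mp3ToPcm 将 MP3 解码为 PCM 文件,decodePcm 将 PCM 切分为固定长度的采样窗口,predictPcm 对每个窗口做归一化并推理。注意 YAMNet 要求 16 kHz 单声道输入,若采样率不同需先重采样。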

/**
 * Decodes an MP3 file into a raw PCM file in the app's cache directory and reports its path.
 *
 * <p>The PCM file is named {@code <mp3 file name>.pcm}. When {@code WcvCodec.decodeMp3ToPcm}
 * returns 0, {@code listener.onFile} receives the absolute PCM path. Otherwise the listener is not
 * called and the empty or partial PCM file is deleted. Exceptions are logged, not rethrown.
 *
 * @param context  used to resolve the app-private cache directory
 * @param mp3File  the MP3 file to decode
 * @param listener receives the PCM path on success; may be {@code null}
 */
public static void mp3ToPcm(Context context, File mp3File, Mp3ToPcmListener listener) {
        File pcmFile = null;
        boolean decoded = false;
        try {
            Log.i(TAG, "TFL::predictMp3, mp3ToPcm");

            // Resolve the directory through the Context instead of hard-coding "/data/data/<pkg>",
            // which is wrong for secondary users and work profiles (/data/user/<id>/<pkg>). The PCM
            // file is a temporary intermediate, so the cache directory is the appropriate place.
            pcmFile = new File(context.getCacheDir(), mp3File.getName() + ".pcm");
            String pcmUri = pcmFile.getAbsolutePath();

            // Pre-create the output file; an already existing file is fine, so the result is ignored.
            // NOTE(review): presumably required by WcvCodec — confirm.
            pcmFile.createNewFile();

            int res = WcvCodec.decodeMp3ToPcm(mp3File.getAbsolutePath(), pcmUri);
            decoded = res == 0;
            if (!decoded) {
                Log.e(TAG, "TFL::predictMp3, mp3ToPcm, decodeMp3ToPcm failed, res = " + res);
            } else if (listener != null) {
                listener.onFile(pcmUri);
            }
        } catch (Exception e) {
            Log.e(TAG, "TFL::predictMp3, mp3ToPcm, Exception = " + e, e);
        } finally {
            if (!decoded && pcmFile != null) {
                // Best effort: do not leave an empty or partially written PCM file behind.
                pcmFile.delete();
            }
        }
    }



/**
 * Streams a raw PCM file and delivers successive windows of first-channel samples to a listener.
 *
 * <p>The file is read as consecutive 16-bit little-endian values. Only values whose byte offset is
 * a multiple of 4 are kept, i.e. the first channel of interleaved 16-bit stereo
 * (NOTE(review): assumes the decoder always emits 2-channel output — confirm).
 *
 * <p>Kept samples go into a ring buffer of {@code SAMPLE_SIZE} entries. After every
 * {@code frame_of_audio} kept samples, {@code listener.run} receives a {@code PredictPcmRes} whose
 * {@code pcmShorts} is an oldest-first snapshot of that buffer. A trailing partial window is not
 * delivered, and neither is a dangling odd byte at end of file.
 *
 * <p>The PCM file is deleted afterwards, whether or not decoding succeeded. Exceptions are logged,
 * not rethrown.
 *
 * @param pcmUri   path of the PCM file to read; the file is deleted afterwards
 * @param listener receives each window; may be {@code null}
 */
public static void decodePcm(String pcmUri, DecodePcmListener listener) {
        File pcmFile = null;
        try {
            pcmFile = new File(pcmUri);
            streamPcmWindows(pcmFile, pcmUri, listener);
        } catch (Exception e) {
            Log.e(TAG, "TFL::predictMp3, Exception = " + e, e);
        } finally {
            if (pcmFile != null) {
                // Best effort: the PCM file is a temporary intermediate.
                pcmFile.delete();
            }
        }
    }

    /**
     * Reads {@code pcmFile} and hands one window per {@code frame_of_audio} first-channel samples
     * to {@code listener}. See {@code decodePcm} for the format assumptions.
     */
    private static void streamPcmWindows(File pcmFile, String pcmUri, DecodePcmListener listener)
            throws java.io.IOException {
        List<PredictPcmRes> pcmResList = new ArrayList<>();

        // Only the last (up to) 15 characters of the file name are used in log lines.
        String pcmName = pcmFile.getName();
        pcmName = pcmName.substring(Math.max(0, pcmName.length() - 15));
        Log.i(WcvCodec.TAG, "decodePcm, pcmName = " + pcmName);

        long totalBytes = pcmFile.length();
        long sliceCnt = totalBytes / DATA_CAPACITY;
        Log.i(WcvCodec.TAG, "decodePcm"
                + ", sliceCnt = " + sliceCnt
                + " = totalBytes / DATA_CAPACITY = "
                + totalBytes + " / " + DATA_CAPACITY
        );

        Short[] pcmShorts = new Short[SAMPLE_SIZE];
        int sliceIdx = 0;
        int idxForSeconds = 0;
        long accRead = 0;

        byte[] twoBytes = new byte[2];
        // Wrap once; getShort(0) decodes whatever the latest read placed in twoBytes.
        ByteBuffer sampleBuf = ByteBuffer.wrap(twoBytes).order(ByteOrder.LITTLE_ENDIAN);

        // Buffered: 2-byte reads straight from FileInputStream cost one system call per sample.
        try (java.io.InputStream pcmIn =
                     new java.io.BufferedInputStream(new FileInputStream(pcmFile))) {
            while (readExactly(pcmIn, twoBytes)) {
                long idx = accRead; // byte offset of the current 16-bit value
                accRead += twoBytes.length;

                if (idx <= 10) {
                    Log.i(WcvCodec.TAG, "decodePcm"
                            + ", idx=" + idx
                            + ", totalBytes=" + totalBytes
                            + ", pcmName=" + pcmName
                    );
                }

                // Byte offsets 0, 4, 8, ... hold the first channel; skip the second channel.
                if (idx % 4 != 0) {
                    continue;
                }

                short orderedShort = sampleBuf.getShort(0);
                pcmShorts[idxForSeconds++ % pcmShorts.length] = orderedShort;

                if (BuildConfig.DEBUG && idx >= 6000 && idx < 6020) {
                    Log.i(WcvCodec.TAG, "decodePcm"
                            + ", pcmName=" + pcmName
                            + ", idx=" + idx
                            + ", [0x" + Integer.toHexString(twoBytes[0] & 0xff)
                            + " 0x" + Integer.toHexString(twoBytes[1] & 0xff)
                            + "] = " + orderedShort
                    );
                }

                if (idxForSeconds % frame_of_audio == 0) {
                    sliceIdx += 1;

                    if (listener != null) {
                        PredictPcmRes res = new PredictPcmRes();
                        res.pcmUri = pcmUri;
                        // Hand over a snapshot. pcmShorts keeps being overwritten by later samples,
                        // so sharing it would silently change windows the listener already holds.
                        res.pcmShorts = oldestFirstCopy(pcmShorts, idxForSeconds);
                        res.durIdx = sliceIdx;
                        res.durCnt = sliceCnt;
                        Log.d(TAG, "TFL::predictMp3, decodePcm, " + res.durIdx + " / " + res.durCnt);

                        listener.run(res, pcmResList);
                    }
                }
            }
        }

        Log.i(WcvCodec.TAG, "decodePcm"
                + ", sliceIdx=" + sliceIdx
                + ", sliceCnt=" + sliceCnt
                + ", accRead=" + accRead
        );
    }

    /**
     * Fills {@code buf} completely from {@code in}, looping over short reads.
     *
     * @return {@code true} if {@code buf.length} bytes were read; {@code false} at end of stream,
     *         including when only part of {@code buf} could be filled (the partial value is dropped)
     */
    private static boolean readExactly(java.io.InputStream in, byte[] buf) throws java.io.IOException {
        int filled = 0;
        while (filled < buf.length) {
            int n = in.read(buf, filled, buf.length - filled);
            if (n < 0) {
                return false;
            }
            filled += n;
        }
        return true;
    }

    /**
     * Copies ring buffer {@code ring} so the oldest slot comes first.
     *
     * <p>{@code written} is the total number of samples written so far. When {@code written} is a
     * multiple of {@code ring.length} this is a plain copy. Slots never written yet stay
     * {@code null} and end up at the front.
     */
    private static Short[] oldestFirstCopy(Short[] ring, int written) {
        int oldest = written % ring.length;
        Short[] window = new Short[ring.length];
        System.arraycopy(ring, oldest, window, 0, ring.length - oldest);
        System.arraycopy(ring, 0, window, ring.length - oldest, oldest);
        return window;
    }



    /**
     * Normalizes one PCM window to floats and runs it through the classifiers.
     *
     * <p>Each sample in {@code res.pcmShorts} is divided by 32768 ({@code Short.MAX_VALUE + 1}), and
     * {@code null} entries become 0. The resulting array is stored in {@code res.normalizedFloats}
     * and passed to {@code AuroraHelper.runInference}, whose result is not used here. It is also
     * passed to {@code ModelExecutor.execute}, whose label/score pair is stored in {@code res.pair}.
     *
     * <p>{@code listener.onSuccess()} is called only if every step completes. Any exception is
     * printed and logged, then swallowed.
     *
     * <p>NOTE(review): AuroraHelper and ModelExecutor are both constructed on every call. If they
     * load model files, consider creating them once. The Aurora inference result is discarded —
     * confirm it is needed at all.
     *
     * @param res      window to classify; its normalizedFloats and pair fields are filled in
     * @param context  passed to the model helpers
     * @param listener notified on success; may be {@code null}
     */
    public static void predictPcm(PredictPcmRes res, Context context, PredictPcmListener listener) {
        try {
            final Short[] samples = res.pcmShorts;
            Log.d(TAG, "TFL::predictPcm, pcm length = " + samples.length);

            // Map signed 16-bit samples into [-1.0, 1.0).
            final float fullScale = Short.MAX_VALUE + 1;
            float[] normalizedFloats = new float[samples.length];
            int slot = 0;
            for (Short sample : samples) {
                normalizedFloats[slot++] = (sample == null) ? 0f : sample / fullScale;
            }
            res.normalizedFloats = normalizedFloats;
            Log.d(TAG, "TFL::predictPcm, normalizedFloats length = " + res.normalizedFloats.length);

            // Aurora pass (presumably a PyTorch model, judging by PytorchModel); its output is unused.
            AuroraHelper aurora = new AuroraHelper();
            aurora.PytorchModel(context);
            aurora.runInference(res.normalizedFloats);

            // Main classifier: labels (first) and scores (second).
            Pair<ArrayList<String>, ArrayList<Float>> pair =
                    new ModelExecutor(context, false).execute(normalizedFloats);
            res.pair = pair;

            if (Lg.DEBUG) {
                Log.i(TAG, "TFL::predictPcm, pair first = " + pair.getFirst().toString());
                Log.i(TAG, "TFL::predictPcm, pair second = " + pair.getSecond().toString());
            }

            if (listener == null) {
                return;
            }
            listener.onSuccess();
        } catch (Exception e) {
            e.printStackTrace();
            Log.d(TAG, "TFL::predictMp3, predictPcm, Exception = " + e);
        }
    }
© www.soinside.com 2019 - 2024. All rights reserved.