我用过这个音频分类应用程序,我真的很喜欢它: https://github.com/tensorflow/examples/tree/master/lite/examples/audio_classification/android/
此示例使用手机麦克风开始录音(现场录音)并对输入进行分类。
这是使用该模型的代码:
// Instantiate the TFLite YAMNet wrapper generated by ML Model Binding.
val model = YamnetClassification.newInstance(this)
// Creates inputs for inference.
// The input tensor is a flat buffer of 15600 float32 samples
// (per the YAMNet model card: ~0.975 s of 16 kHz mono audio — confirm against the model docs).
val audioClip = TensorBuffer.createFixedSize(intArrayOf(15600), DataType.FLOAT32)
// NOTE(review): `byteBuffer` is defined elsewhere; it presumably holds 15600
// little-endian float32 samples normalized to [-1, 1] — TODO confirm at the call site.
audioClip.loadBuffer(byteBuffer)
// Runs model inference and gets result.
val outputs = model.process(audioClip)
// One score per YAMNet class for this clip.
val scores = outputs.scoresAsTensorBuffer
我有一个音频剪辑(例如 R.raw.s1.wav 中的音频),我想将其作为输入,并借助 YAMNet 模型对其进行分类。
如何更新上面的示例以将音频剪辑作为输入而不是现场录音?我认为音频剪辑需要转换为ByteBuffer?如何为音频剪辑完成此转换?
做法类似于 YAMNet 文档中的 Python 代码(https://www.kaggle.com/models/google/yamnet/frameworks/tensorFlow2/variations/yamnet/versions/1?tfhub-redirect=true):
/**
 * Decodes an MP3 file to a raw PCM file in the app's private data directory and
 * reports the resulting path to the listener.
 *
 * @param context  used to resolve the app package name for the output path
 * @param mp3File  source MP3 file
 * @param listener invoked with the PCM path on success; may be null
 */
public static void mp3ToPcm(Context context, File mp3File, Mp3ToPcmListener listener) {
    try {
        Log.i(TAG, "TFL::predictMp3, mp3ToPcm");
        // NOTE(review): hard-coded /data/data path; context.getFilesDir() would be more
        // portable — kept as-is so callers find the .pcm file where they expect it.
        String pcmUri = "/data/data/" + context.getPackageName() + "/" + mp3File.getName() + ".pcm";
        boolean createRes = new File(pcmUri).createNewFile();
        if (!createRes) {
            // File already existed — presumably overwritten by the decoder; TODO confirm.
            Log.w(TAG, "TFL::predictMp3, mp3ToPcm, pcm file already exists: " + pcmUri);
        }
        String mp3Uri = mp3File.getAbsolutePath();
        int res = WcvCodec.decodeMp3ToPcm(mp3Uri, pcmUri);
        if (res == 0) {
            if (listener != null) {
                listener.onFile(pcmUri);
            }
        } else {
            // Previously swallowed silently: surface decoder failures in the log.
            Log.e(TAG, "TFL::predictMp3, mp3ToPcm, decodeMp3ToPcm failed, res = " + res);
        }
    } catch (Exception e) {
        // Pass the throwable so the stack trace lands in logcat (replaces printStackTrace()).
        Log.e(TAG, "TFL::predictMp3, mp3ToPcm, Exception = " + e, e);
    }
}
/**
 * Streams a raw PCM file and slices it into fixed-size sample windows, invoking the
 * listener once per completed window. The PCM file is deleted when decoding finishes.
 *
 * Assumes the PCM data is 16-bit little-endian; the idx % 4 == 0 filter keeps every
 * other 16-bit sample, i.e. the left channel of interleaved 16-bit stereo.
 *
 * NOTE(review): pcmShorts is a single reusable ring buffer (idxForSeconds wraps with
 * % pcmShorts.length) and the SAME array reference is handed to every PredictPcmRes —
 * listeners must consume it synchronously before the loop overwrites it; confirm.
 * SAMPLE_SIZE, DATA_CAPACITY and frame_of_audio are declared elsewhere in this class.
 *
 * @param pcmUri   path to the PCM file to read (deleted in finally)
 * @param listener called with each completed slice; may be null
 */
public static void decodePcm(String pcmUri, DecodePcmListener listener) {
    File pcmFile = null;
    FileInputStream pcmFis = null;
    try {
        //Log.d(TAG, "TFL::predictMp3, decodePcm");
        List<PredictPcmRes> pcmResList = new ArrayList<>();
        //Log.d(TAG, "TFL::predictMp3, decodePcm, new pcmResList");
        pcmFile = new File(pcmUri);
        pcmFis = new FileInputStream(pcmFile);
        String pcmName = pcmFile.getName();
        // Keep only the last (up to) 15 characters of the name for compact logging.
        pcmName = pcmName.substring(MathUtils.clamp(pcmName.length() - 15, 0, pcmName.length()), pcmName.length());
        Log.i(WcvCodec.TAG, "decodePcm, pcmName = " + pcmName);
        long totalBytes = pcmFile.length();
        Short[] pcmShorts = new Short[SAMPLE_SIZE];
        // Expected number of slices for progress reporting (durIdx / durCnt below).
        long sliceCnt = totalBytes / DATA_CAPACITY;
        int sliceIdx = 0;
        Log.i(WcvCodec.TAG, "decodePcm"
            + ", sliceCnt = " + sliceCnt
            + "= res.totalBytes / 4 * 8000L = "
            + totalBytes + " / " + (4 * 8000L)
        );
        int accRead = 0;
        byte[] twoBytes = new byte[2];
        // Read the file two bytes (one 16-bit sample) at a time until EOF.
        for (int idx = 0, idxForSeconds = 0, read; (read = pcmFis.read(twoBytes, 0, 2)) != -1; idx += read) {
            accRead += read;
            // Log only the first few iterations to avoid flooding logcat.
            if (idx <= 10 /*|| idx > 1269680*/) {
                Log.i(WcvCodec.TAG, "decodePcm"
                    + ", idx=" + idx
                    + ", totalBytes=" + totalBytes
                    + ", pcmName=" + pcmName
                );
            }
            // order with LITTLE_ENDIAN
            ByteBuffer byteBuffer = ByteBuffer.wrap(twoBytes);
            ByteBuffer orderedBuffer = byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
            short orderedShort = orderedBuffer.getShort();
            // left channel only
            if (idx % 4 == 0) {
                // Write into the ring buffer, wrapping when a window is full.
                pcmShorts[idxForSeconds++ % pcmShorts.length] = orderedShort;
                if (BuildConfig.DEBUG) {
                    // Spot-check a small window of raw bytes vs decoded values.
                    if (idx >= 6000 && idx < 6020) {
                        Log.i(WcvCodec.TAG, "decodePcm"
                            + ", pcmName=" + pcmName
                            + ", idx=" + idx
                            + ", [0x" + Integer.toHexString(twoBytes[0] & 0xff)
                            + " 0x" + Integer.toHexString(twoBytes[1] & 0xff)
                            + "] = " + orderedShort
                        );
                    }
                }
                //Log.i(WcvCodec.TAG, "decodePcm" + ", idxForSeconds=" + idxForSeconds);
                // A full window of samples is ready: hand it to the listener.
                if (idxForSeconds % frame_of_audio == 0) {
                    sliceIdx += 1;
                    if (listener != null) {
                        PredictPcmRes res = new PredictPcmRes();
                        res.pcmUri = pcmUri;
                        // NOTE(review): shares the live buffer, not a copy — see method doc.
                        res.pcmShorts = pcmShorts;
                        res.durIdx = sliceIdx;
                        res.durCnt = sliceCnt;
                        Log.d(TAG, "TFL::predictMp3, decodePcm, " + res.durIdx + " / " + res.durCnt);
                        listener.run(res, pcmResList);
                    }
                }
            }
        }
        Log.i(WcvCodec.TAG, "decodePcm"
            + ", sliceIdx=" + sliceIdx
            + ", sliceCnt=" + sliceCnt
            + ", accRead=" + accRead
        );
    } catch (Exception e) {
        e.printStackTrace();
        Log.d(TAG, "TFL::predictMp3, Exception = " + e);
    } finally {
        // Always close the stream and remove the intermediate PCM file.
        if (pcmFis != null) {
            ReleaseUtil.release(pcmFis);
        }
        if (pcmFile != null) {
            boolean res = pcmFile.delete();
        }
    }
}
/**
 * Normalizes one PCM slice to floats in [-1, 1) and runs it through the models,
 * storing the normalized samples and the (labels, scores) pair back on {@code res}.
 *
 * @param res      slice produced by decodePcm; pcmShorts is read, normalizedFloats
 *                 and pair are written
 * @param context  used by the model helpers to load model assets
 * @param listener notified on success; NOTE(review): never notified on failure —
 *                 consider adding an onError callback so callers can react
 */
public static void predictPcm(PredictPcmRes res, Context context, PredictPcmListener listener) {
    try {
        Log.d(TAG, "TFL::predictPcm, pcm length = " + res.pcmShorts.length);
        // pcm to normalized floats: dividing by 32768 maps the full short range to [-1, 1)
        float[] normalizedFloats = new float[res.pcmShorts.length];
        for (int i = 0; i < res.pcmShorts.length; i++) {
            // Boxed entries may still be null if the window was never filled; treat as silence.
            short ps = res.pcmShorts[i] != null ? res.pcmShorts[i] : 0;
            normalizedFloats[i] = 1f * ps / (Short.MAX_VALUE + 1);
        }
        res.normalizedFloats = normalizedFloats;
        Log.d(TAG, "TFL::predictPcm, normalizedFloats length = " + res.normalizedFloats.length);
        // aurora
        AuroraHelper ah = new AuroraHelper();
        ah.PytorchModel(context);
        ah.runInference(res.normalizedFloats);
        // do predict on normalized floats
        ModelExecutor executor = new ModelExecutor(context, false);
        Pair<ArrayList<String>, ArrayList<Float>> pair = executor.execute(normalizedFloats);
        res.pair = pair;
        if (Lg.DEBUG) {
            Log.i(TAG, "TFL::predictPcm, pair first = " + pair.getFirst().toString());
            Log.i(TAG, "TFL::predictPcm, pair second = " + pair.getSecond().toString());
        }
        if (listener != null) {
            listener.onSuccess();
        }
    } catch (Exception e) {
        // Log at error level with the throwable (replaces printStackTrace + Log.d),
        // consistent with mp3ToPcm's error handling.
        Log.e(TAG, "TFL::predictMp3, predictPcm, Exception = " + e, e);
    }
}