State pattern - determining the next state from multiple options

Question

I'm implementing a noise gate in Python using the state design pattern.

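My implementation takes an array of audio samples and, using the noise gate's parameters, the audio sample magnitudes and the gate's state, determines a coefficient in the range [0, 1] by which the current audio sample value should be multiplied.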

The states I have defined are OpenState, ClosedState, OpeningState and ClosingState. I believe the diagram below contains all the state transitions I need to consider.

[Figure: state transition diagram]
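The diagram itself is not reproduced here. As a sketch, the transitions implied by the state classes further down (Closed to Opening, Opening to Open, Open to Closing, and Closing to either Closed or Open) can be written as a plain lookup table. ALLOWED_TRANSITIONS and is_allowed are hypothetical names, not part of the question's code.

```python
from typing import Dict, FrozenSet

# Hypothetical transition table, reconstructed from the state classes'
# docstrings and transition methods; keys are source states, values are
# the states each one may move to.
ALLOWED_TRANSITIONS: Dict[str, FrozenSet[str]] = {
    "ClosedState": frozenset({"OpeningState"}),
    "OpeningState": frozenset({"OpenState"}),
    "OpenState": frozenset({"ClosingState"}),
    "ClosingState": frozenset({"ClosedState", "OpenState"}),
}


def is_allowed(source: str, target: str) -> bool:
    """Return True if the table permits a move from source to target."""
    return target in ALLOWED_TRANSITIONS.get(source, frozenset())


if __name__ == "__main__":
    print(is_allowed("ClosingState", "OpenState"))  # True
    print(is_allowed("ClosedState", "OpenState"))   # False: must pass through OpeningState
```

A table like this could back a debug-time assertion in Context.transition_to; the state classes would still decide when to move.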

When the gate is in ClosingState, there are two possible transitions:

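  1. ClosingState -> ClosedState: this happens if the release period has elapsed and no other peak has exceeded the threshold during it.
  2. ClosingState -> OpenState: this happens if a peak exceeds the threshold at some point during the release period.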

The part of my code that decides which state to transition to is this method in the ClosingState class:

def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
    '''
    There are two possible states that we can transition to from ClosingState.
    Feels strange to introduce conditionals to determine state transition(?)
    '''
    # This doesn't feel right introducing these conditionals here.
    if sample_mag > self.context.lin_thresh:
        self.transition_pending = True
        self.new_state = OpenState()
        return True
    
    if self.sample_counter >= self.context.release_period_in_samples-1:
        self.transition_pending = True
        self.new_state = ClosedState()
        return True

    # Neither condition is met: stay in ClosingState
    return False
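One way to drop the transition_pending / new_state bookkeeping is to have the check return the next state object (or None) and let the caller perform the switch. The sketch below uses stand-in classes; the next_state method, the run driver, and the hard-coded threshold and release length are hypothetical simplifications, not the question's API. The threshold check still comes first, so a peak on the last release sample reopens the gate.

```python
from typing import List, Optional


class State:
    """Stand-in for the question's State base class (simplified)."""

    def __init__(self) -> None:
        self.sample_counter = 0

    def next_state(self, sample_mag: float) -> Optional["State"]:
        """Return the state to switch to, or None to stay put."""
        return None


class OpenState(State):
    pass


class ClosedState(State):
    pass


class ClosingState(State):
    # Hypothetical fixed values; in the question these come from self.context.
    lin_thresh = 0.1
    release_period_in_samples = 4

    def next_state(self, sample_mag: float) -> Optional[State]:
        # Threshold first, so a peak on the last release sample reopens the gate.
        if sample_mag > self.lin_thresh:
            return OpenState()
        if self.sample_counter >= self.release_period_in_samples - 1:
            return ClosedState()
        return None


def run(mags: List[float]) -> str:
    """Feed magnitudes to a ClosingState and report which state it ends in."""
    state: State = ClosingState()
    for mag in mags:
        new_state = state.next_state(mag)
        if new_state is not None:
            # The caller (Context, in the question) performs the switch.
            state = new_state
            break
        state.sample_counter += 1
    return type(state).__name__
```

The conditionals remain, but they now only decide which state comes next; everything each state does still lives in its own class.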

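My question is simply whether it's acceptable to use these conditionals to determine which state to transition to. It feels like reintroducing the kind of code that the state pattern is supposed to get rid of, but the alternative isn't obvious to me.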


Edit: I think I've clarified this for myself. Using the state pattern lets us get rid of code like this:

if state == "closed":
    pass  # do something
elif state == "open":
    pass  # do something
elif state == "closing":
    pass  # do something
elif state == "opening":
    pass  # do something

The conditionals I was questioning are different from this: I'm checking conditions on the data, not checking which state I'm in.
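To make that distinction concrete, a minimal sketch: an if/elif chain like the one above dispatches on which state the gate is in, and the state pattern replaces it with a method call on the state object. The class names, the coefficient method and the linear ramp are hypothetical simplifications; guards on sample data, like the ones in ClosingState, would still live inside the relevant class.

```python
class Closed:
    def coefficient(self, ramp_position: float) -> float:
        return 0.0


class Opening:
    def coefficient(self, ramp_position: float) -> float:
        # Linear ramp for illustration; the real gate indexes attack_ramp.
        return ramp_position


class Open:
    def coefficient(self, ramp_position: float) -> float:
        return 1.0


class Closing:
    def coefficient(self, ramp_position: float) -> float:
        # Linear ramp for illustration; the real gate indexes release_ramp.
        return 1.0 - ramp_position


def gain(state, ramp_position: float) -> float:
    # No 'if state == "closed": ... elif ...' chain: each state object
    # supplies its own behaviour, which is what the state pattern removes.
    return state.coefficient(ramp_position)


if __name__ == "__main__":
    for s in (Closed(), Opening(), Open(), Closing()):
        print(type(s).__name__, gain(s, 0.25))
```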


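Below is a minimal example. It may not be needed for my conceptual question above, but I'm including it just in case. An example audio file can be found here.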

SO_ramp_functions.py

import numpy as np

def ramp_linear_increase(num_points):
    ''' Function defining a linear increase from 0 to 1 in num_points samples '''
    return np.linspace(0, 1, num_points)

def ramp_linear_decrease(num_points):
    ''' Function defining a linear decrease from 1 to 0 in num_points samples '''
    return np.linspace(1, 0, num_points)

def ramp_poly_increase(num_points):
    ''' Generate an array of coefficient values for the attack period '''
    x = np.arange(num_points, 0, -1)
    attack_coef_arr = 1 - (x/num_points)**4
    
    # Make sure the start and end are 0 and 1, respectively
    attack_coef_arr[0] = 0
    attack_coef_arr[-1] = 1
    
    return attack_coef_arr


def ramp_poly_decrease(num_points):
    ''' Generate an array of coefficient values for the release period '''
    x = np.arange(num_points)
    release_coef_arr = 1 - (x/num_points)**4
    
    # Make sure the start and end are 1 and 0, respectively
    release_coef_arr[0] = 1
    release_coef_arr[-1] = 0
    
    return release_coef_arr
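Written out, with N = num_points and k = 0, ..., N-1 the sample index, the two polynomial ramps compute the values below. The assignments to element 0 are already satisfied by the formulas; only the final element is actually changed by the overrides.

```latex
\begin{aligned}
a_k &= 1 - \left(1 - \tfrac{k}{N}\right)^4, \quad 0 \le k < N-1, \qquad a_{N-1} = 1,\\
r_k &= 1 - \left(\tfrac{k}{N}\right)^4, \quad 0 \le k < N-1, \qquad r_{N-1} = 0.
\end{aligned}
```

So the attack ramp rises steeply and eases into 1, while the release ramp stays near 1 and then drops steeply. Without the overrides, the last values would be 1 - N^(-4) and 1 - ((N-1)/N)^4 respectively.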

SO_gate_states.py

from abc import ABC, abstractmethod


class State(ABC):
    """
    The base State class declares methods that all concrete States should
    implement and also provides a backreference to the Context object,
    associated with the State. This backreference can be used by States to
    transition the Context to another State.
    """

    @property
    def context(self):
        return self._context


    @context.setter
    def context(self, context) -> None:
        self._context = context


    @abstractmethod
    def get_sample_coefficient(self, sample_mag: float) -> float:
        pass
    
    
    @abstractmethod
    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        pass
    
    
    @abstractmethod
    def on_entry(self):
        pass
    
    
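    @abstractmethod
    def on_exit(self):
        pass


    @abstractmethod
    def handle_state_transition(self) -> None:
        # Called by Context after each sample; every concrete state defines it
        pass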


"""
Concrete States implement various behaviors, associated with a state of the
Context.
"""

class ClosedState(State):
    
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False
    
    
    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        ''' 
        Get the appropriate coefficient value to multiply with the current
        audio sample value.
        
        In the closed state, the coefficient is always 0.0.
        '''
        
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return 0.0
        
    
    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        '''
        Check if a condition is met that initiates a transition.
        For ClosedState, we want to check if the sample magnitude exceeds the threshold.
        '''
        return sample_mag > self.context.lin_thresh
    
    
    def on_entry(self):
        pass
    
    
    def on_exit(self):
        pass
        
        
    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(OpeningState())


class OpeningState(State):
    '''
    - In OpeningState, the coefficient is determined by the shape of the
        specified attack ramp.
    
    - The only state we can transition to from OpeningState is OpenState.
    '''
    
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False
    
    
    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        
        self.transition_pending = self.check_if_state_transition_is_due()
        if self.transition_pending:
            return 1.0
        else:
            # Get a value from the gate's attack ramp
            return self.context.attack_ramp[self.sample_counter]
        
        
    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        # Transition to OpenState occurs once attack period has elapsed
        return self.sample_counter >= self.context.attack_period_in_samples
    
    
    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(OpenState())
            self.on_exit()
    
    
    def on_entry(self):
        pass
    
    
    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0
    

class OpenState(State):
    '''
    In OpenState, the coefficient is always 1.0.
    The only state we can transition to from OpenState is ClosingState.
    '''
    
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False
    
    
    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return 1.0
    
    
    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        # The gate can't transition before its hold period has elapsed
        if self.sample_counter < self.context.hold_period_in_samples:
            return False
        else:
            # If the signal magnitude falls below the threshold, we want to
            # transition to ClosingState.
            return sample_mag < self.context.lin_thresh
    
    
    def on_entry(self):
        pass
    
    
    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0
        
        
    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(ClosingState())
            self.on_exit()
    

class ClosingState(State):
    '''    
    - The coefficient is determined by the shape of the specified release ramp.
    - The state can transition to either ClosedState or OpenState.
    '''
    
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False
        self.new_state = None
    
    
    def get_sample_coefficient(self, sample_mag: float=0) -> float:
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return self.context.release_ramp[self.sample_counter]
        
        
    def check_if_state_transition_is_due(self, sample_mag: float=0) -> bool:
        '''
        There are two possible states that we can transition to from ClosingState.
        Feels strange to introduce conditionals to determine state transition(?)
        '''
        # This doesn't feel right introducing these conditionals here.
        if sample_mag > self.context.lin_thresh:
            self.transition_pending = True
            self.new_state = OpenState()
            return True

        if self.sample_counter >= self.context.release_period_in_samples-1:
            self.transition_pending = True
            self.new_state = ClosedState()
            return True

        # Neither condition is met: stay in ClosingState
        return False
        
        
    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(self.new_state)
            self.on_exit()
        
        
    def on_entry(self):
        pass
    
    
    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0


SO_noise_gate_state_pattern.py

import numpy as np
import SO_ramp_functions as rf

'''
The original template code is found here:
    https://refactoring.guru/design-patterns/state/python/example
'''

class AudioConfig:
    '''
    Values that configure audio playback, so they can be set independently
    of, and shared between, different objects that need them.
    '''
    def __init__(self, fs):
        self.fs = fs


class Context:
    """
    This class represents the noise gate.
    
    The Context defines the interface of interest to clients. It also maintains
    a reference to an instance of a State subclass, which represents the current
    state of the Context.
    """

    def __init__(self, audio_config, state) -> None:
        self.audio_config = audio_config
        self.transition_to(state)
        
        # Specify an initial threshold value in dBFS
        self.thresh = -20
                
        # Specify attack, hold, release, and lookahead periods in seconds
        self.attack_time = 0.005  # seconds
        self.hold_time = 0.05  # seconds
        self.release_time = 0.1  # seconds
        self.lookahead_time = 0.005 # seconds
        
        # Calculate attack, hold, and release periods in samples
        self.attack_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.attack_time)        
        self.hold_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.hold_time)
        self.release_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.release_time)
        self.lookahead_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.lookahead_time)
        
        # Define the attack and release multiplier ramps - use strategy pattern?
        self.attack_ramp = rf.ramp_poly_increase(num_points=self.attack_period_in_samples)
        self.release_ramp = rf.ramp_poly_decrease(num_points=self.release_period_in_samples)
        
        # Initialise an attribute to store the processed result
        self.processed_array = None
        self.coef_array = None
        
        # Padding to enable lookahead (a bit of a hack)
        self.lookahead_pad_samples = self.lookahead_period_in_samples#2000
        
        # Attributes for debugging
        self.text_output = []


    def transition_to(self, state):
        """
        The Context allows changing the State object at runtime.
        """

        ##print(f"Context: Transition to {type(state).__name__}")
        self._state = state
        self._state.context = self


    # Setters for gate parameters
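    def set_attack_time(self, new_attack_time: float) -> None:
        self.attack_time = new_attack_time
        self.attack_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.attack_time)
        # Rebuild the ramp so its length matches the new attack period
        self.attack_ramp = rf.ramp_poly_increase(num_points=self.attack_period_in_samples)
    
    
    def set_hold_time(self, new_hold_time: float) -> None:
        self.hold_time = new_hold_time
        self.hold_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.hold_time)
        
        
    def set_release_time(self, new_release_time: float) -> None:
        self.release_time = new_release_time
        self.release_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.release_time)
        # Rebuild the ramp so its length matches the new release period
        self.release_ramp = rf.ramp_poly_decrease(num_points=self.release_period_in_samples)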


    @property
    def thresh(self) -> int:
        return self._thresh
    
    
    @thresh.setter
    def thresh(self, new_thresh: int) -> None:
        self._thresh = new_thresh


    @property
    def lin_thresh(self) -> float:
        return self.dBFS_to_lin(self.thresh)
    
    
    # These staticmethods could equally be defined outside the class
    @staticmethod
    def dBFS_to_lin(dBFS_val):
        ''' Helper method to convert a dBFS value to a linear value [0, 1] '''
        return 10 ** (dBFS_val / 20)
        

    @staticmethod
    def seconds_to_samples(fs, seconds_val):
        ''' Helper method to convert a time (seconds) value to a number of samples '''
        return int(fs * seconds_val)

    
    def process_audio_block(self, audio_array=None):
        '''
        Process an array of audio samples according to the gate's parameters,
        current state, and the sample values in the audio array.
        This implementation includes lookahead logic.
        
        '''
        
        # Initialise one coefficient per output sample (the input length minus the
        # lookahead padding). Set initial coefficient values outside the valid range
        # [0, 1] for easier debugging
        self.coef_array = np.ones(len(audio_array))[:-self.lookahead_pad_samples] * 2
        # Get the magnitude values of the audio array
        self.mag_array = np.abs(audio_array)

        # Iterate through the samples of the mag_arr, updating coef_array values
        for i, sample_mag in enumerate(self.mag_array[:-self.lookahead_pad_samples]):    
            # Get the coefficient value for the current sample, considering a lookahead period
            self.coef_array[i] = self._state.get_sample_coefficient(self.mag_array[i + self.lookahead_period_in_samples])
            # Increment the counter for tracking the samples elapsed in the current state
            self._state.sample_counter += 1
            # Create a log of the state and samples elapsed, for debugging
            self.text_output.append(f"{type(self._state).__name__}. {self._state.sample_counter}. {self.coef_array[i]:.3f}")
            # After processing the current sample, check if a transition is due
            self._state.handle_state_transition()
            
        self.processed_array = self.coef_array * audio_array[:-self.lookahead_pad_samples]
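A quick standalone check of the two conversion helpers with the defaults set in __init__ (fs = 44100, threshold -20 dBFS, 5 ms / 50 ms / 100 ms attack, hold and release). These are copies of the staticmethods renamed to snake_case for a standalone script, not the class methods themselves. Note that int() truncates, so the 5 ms attack (220.5 samples) becomes 220 samples, which is also the length of attack_ramp.

```python
def dbfs_to_lin(dbfs_val: float) -> float:
    """Same formula as Context.dBFS_to_lin: linear amplitude = 10 ** (dBFS / 20)."""
    return 10 ** (dbfs_val / 20)


def seconds_to_samples(fs: int, seconds_val: float) -> int:
    """Same as Context.seconds_to_samples: int() truncates toward zero."""
    return int(fs * seconds_val)


if __name__ == "__main__":
    fs = 44100
    print(dbfs_to_lin(-20))               # default threshold as a linear value
    print(seconds_to_samples(fs, 0.005))  # attack: 220.5 truncated to 220
    print(seconds_to_samples(fs, 0.05))   # hold
    print(seconds_to_samples(fs, 0.1))    # release
```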

main.py

'''
Driver code for the noise gate using the state pattern.

'''

from SO_noise_gate_state_pattern import AudioConfig, Context
from SO_gate_states import ClosedState
import numpy as np
import audiofile
import matplotlib.pyplot as plt
import time


# Define some helper/test functions
def load_audio(fpath):
    data, fs = audiofile.read(fpath)
    data = data.T
    if len(data.shape) == 2:
        data = data[:,0]    # convert to mono
    return data


def test_gate_coef_values_are_valid(coef_arr):
    print("Testing gate coef_array values")
    assert np.all((coef_arr >= 0) & (coef_arr <= 1)), "gate coefficients must lie in [0, 1]"


if __name__ == "__main__":
    
    # The client code.
    # Configure some audio properties
    audio_config = AudioConfig(fs=44100)
    
    # Create a "context" instance (this is like the NoiseGate class)
    context = Context(audio_config, ClosedState())
    
    # Load audio from file
    sig = load_audio(fpath="./snare_test.wav")
    # Zero-pad the audio array to enable lookahead (experimental)
    sig = np.concatenate((sig, np.zeros(context.lookahead_pad_samples)))
    
    # Process the whole array and time it
    start_time = time.perf_counter()
    context.process_audio_block(sig)
    end_time = time.perf_counter()
    print(f"Time taken to process {len(sig)/audio_config.fs:.2f} seconds of audio: {end_time - start_time:.2f} seconds")
    
    # Some testing on the result
    test_gate_coef_values_are_valid(context.coef_array)
    
    # Plot the result
    plt.plot(context.mag_array, color='blue', linewidth=1, label='signal magnitude')
    plt.plot(context.coef_array, color='green', label='gate coefficient')
    plt.plot(np.abs(context.processed_array), color='orange', label='gate output')
    plt.axhline(context.lin_thresh, color='black', linewidth=1, label='gate threshold')
    plt.legend()
    plt.show()
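The lookahead in process_audio_block is the part most easily misread, so here is a stripped-down sketch of only its indexing (no ramps or states; the function name is hypothetical). It zero-pads the signal as main.py does and shows that a peak is detected lookahead samples before it arrives. One difference is deliberate: it slices with [lookahead:], because the original's [:-self.lookahead_pad_samples] slices would return empty arrays if the lookahead came out as 0 samples.

```python
import numpy as np


def lookahead_above_threshold(signal: np.ndarray, thresh: float, lookahead: int) -> np.ndarray:
    """For each original sample i, report whether |signal| exceeds thresh at i + lookahead.

    Mirrors the indexing in process_audio_block: the magnitude array is
    zero-padded by `lookahead` samples and sample i is judged from
    padded[i + lookahead]. The result has the same length as `signal`.
    """
    padded = np.concatenate((np.abs(signal), np.zeros(lookahead)))
    return padded[lookahead:] > thresh


if __name__ == "__main__":
    sig = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 0.0])
    above = lookahead_above_threshold(sig, thresh=0.5, lookahead=2)
    print(np.flatnonzero(above))  # the peak at index 4 is seen at index 2
```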

Answer

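It's fine to put logic in place to choose the correct state. However, if you're concerned about managing that code, you can use other patterns to manage the complexity; I think the Factory Method or Chain of Responsibility patterns could be useful. That said, overusing design patterns can make your code more complicated. I would rather take the code that is likely to change, encapsulate it in a function with a meaningful name, and keep my state_transition method clear.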

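So if your conditions are fragile (they need to stay flexible enough to accommodate future changes), create functions to represent them. Similarly, if your transition logic needs to be flexible, create functions for the transition logic itself.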
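Applied to the question's ClosingState, that suggestion might look like the sketch below: each condition gets a descriptive function name, and the transition logic just lists them in priority order. The function names are hypothetical, and the transition logic is written as a free function returning a state name to keep the sketch self-contained.

```python
from typing import Optional


def peak_exceeds_threshold(sample_mag: float, lin_thresh: float) -> bool:
    """A new peak during the release should reopen the gate."""
    return sample_mag > lin_thresh


def release_has_elapsed(sample_counter: int, release_period_in_samples: int) -> bool:
    """True once the last sample of the release ramp has been reached."""
    return sample_counter >= release_period_in_samples - 1


def closing_transition(
    sample_mag: float,
    sample_counter: int,
    lin_thresh: float,
    release_period_in_samples: int,
) -> Optional[str]:
    """Name of the state to move to from ClosingState, or None to stay."""
    # With the conditions named, this reads like the state diagram itself.
    if peak_exceeds_threshold(sample_mag, lin_thresh):
        return "OpenState"
    if release_has_elapsed(sample_counter, release_period_in_samples):
        return "ClosedState"
    return None
```

If a condition later changes (for example, a different threshold for reopening), only the matching predicate needs editing.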

"It feels like reintroducing the kind of code that the state pattern is supposed to get rid of"

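Remember that the goal of the state pattern is to separate the states so that each can be managed independently, which reduces side effects. Its purpose is not to reduce the complexity within each state; for that, you should consider applying other patterns.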
