目录
类属性
方法
__init__(self)
async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool)
async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool)
async def reset_shard(self, shard: Shard)
async def ensure_shard(self, shard: Shard)
总结
step
reset
构造函数 __init__
方法 step
方法 __call__
方法 reset
注意点
sharded_inference_engine:MLXDynamicShardInferenceEngine
这段代码定义了一个名为 MLXDynamicShardInferenceEngine
的类,它继承自一个名为 InferenceEngine
的基类(尽管基类的具体实现没有给出,但我们可以从子类推断出一些行为)。这个类是为了在分布式或分片环境中进行模型推理而设计的,特别是针对那些被分片存储或部署的模型。下面是对这个类及其方法的详细解释:
类属性
shard
: 用于存储当前激活的分片信息。在推理过程中,这个属性会指向当前正在使用的模型分片。
方法
__init__(self)
- 类的构造函数。它