"""Neuron actor: a weighted-sum-plus-bias unit driven by inbox messages."""

import logging
import math
import random
from typing import Optional

from mathema.actors.actor import Actor

log = logging.getLogger(__name__)


def tanh(x):
    """Return the hyperbolic tangent of *x* (thin wrapper over ``math.tanh``)."""
    return math.tanh(x)


# Activation functions selectable by name through ``Neuron(af_name=...)``.
_ACTIVATIONS = {"tanh": tanh}


class Neuron(Actor):
    """Actor that sums weighted inputs plus a bias and emits ``af(sum)``.

    Messages handled by :meth:`run` (first tuple element is the tag):

    * ``("forward", from_id, data)`` - store *data* (a scalar or a list/tuple of
      values) for input *from_id*. Non-recurrent inputs are consumed in the
      current cycle. Recurrent inputs are buffered and become visible at the next
      ``cycle_start``. When every input is ready, the neuron fires.
    * ``("tick",)`` - fire if every input is ready.
    * ``("get_backup",)`` - send ``("backup_from_neuron", nid, idps)`` to
      ``cx_pid``. ``idps`` holds ``(input_id, weights)`` pairs followed by
      ``("bias", [bias])``.
    * ``("weight_backup",)`` / ``("weight_restore",)`` - snapshot / restore the
      input weights and bias.
    * ``("weight_perturb",)`` - randomly perturb weights and bias.
    * ``("cycle_start",)`` - begin a new cycle. Recurrent inputs take their
      buffered value; non-recurrent inputs are cleared.
    * ``("terminate",)`` - leave the run loop.

    Firing sends ``("forward", nid, [output])`` to every pid in ``outputs``.
    Afterwards all inputs are reset to "not received".
    """

    # Each selected weight moves by a uniform offset in
    # [-PERTURB_DELTA_MAG / 2, PERTURB_DELTA_MAG / 2).
    PERTURB_DELTA_MAG = 2.0 * math.pi
    # After perturbation, weights and bias are clamped to
    # [-WEIGHT_SAT_LIMIT, WEIGHT_SAT_LIMIT].
    WEIGHT_SAT_LIMIT = 2.0 * math.pi

    def __init__(self, nid, cx_pid, af_name, input_idps, output_pids, bias: Optional[float] = None):
        """Build the neuron's input table.

        Args:
            nid: Neuron id; used in the actor name and in outgoing messages.
            cx_pid: Handle that receives ``get_backup`` replies via ``await send(...)``
                (presumably the cortex actor; confirm against caller).
            af_name: Activation name. Unknown names fall back to ``tanh`` with a warning.
            input_idps: Iterable of ``(input_id, weights, recurrent)`` triples.
                An entry whose id is ``"bias"`` sets the bias from ``weights[0]``
                instead of creating an input.
            output_pids: Handles that receive this neuron's output via ``await send(...)``.
            bias: Initial bias; ``None`` means 0.0. Overridden by a ``"bias"`` entry.
        """
        super().__init__(f"Neuron-{nid}")
        self.nid = nid
        self.cx_pid = cx_pid
        self.af = _ACTIVATIONS.get(af_name)
        if self.af is None:
            # Fall back to tanh, but surface the misconfiguration rather than hide it.
            log.warning("Neuron %s: unknown activation %r, falling back to tanh", nid, af_name)
            self.af = tanh
        # input_id -> {"weights", "got", "val", "recurrent", "next_val"}
        self.inputs = {}
        # Input ids in declaration order; fixes the summation order.
        self.order = []
        self._has_recurrent = False
        self.bias = float(bias) if bias is not None else 0.0

        for inp_id, weights, recurrent in input_idps:
            recurrent = bool(recurrent)
            if inp_id == "bias":
                self.bias = float(weights[0])
                continue
            self.order.append(inp_id)
            self.inputs[inp_id] = {
                "weights": list(weights),
                "got": False,
                "val": [],
                "recurrent": recurrent,
                "next_val": [],
            }
            if recurrent:
                self._has_recurrent = True

        # Snapshot written by "weight_backup", read by "weight_restore".
        self._backup_inputs = None
        self._backup_bias = None
        self.outputs = output_pids
        log.debug("Neuron %s: inputs=%s, bias=%s", nid, list(self.inputs), self.bias)

    async def run(self) -> None:
        """Process inbox messages until a ``("terminate",)`` message arrives."""
        handlers = {
            "forward": self._on_forward,
            "tick": self._on_tick,
            "get_backup": self._on_get_backup,
            "weight_backup": self._on_weight_backup,
            "weight_restore": self._on_weight_restore,
            "weight_perturb": self._on_weight_perturb,
            "cycle_start": self._on_cycle_start,
        }
        while True:
            msg = await self.inbox.get()
            tag = msg[0]
            if tag == "terminate":
                return
            handler = handlers.get(tag)
            if handler is None:
                log.debug("Neuron %s: ignoring unknown message tag %r", self.nid, tag)
                continue
            await handler(msg)

    # ----------------------------------------------------------------- firing

    @staticmethod
    def _as_values(data) -> list:
        """Normalize a forward payload to a fresh list of values."""
        if isinstance(data, (list, tuple)):
            return list(data)
        return [float(data)]

    def _ready(self) -> bool:
        """Return True when the neuron has inputs and every one has been received."""
        return bool(self.order) and all(self.inputs[i]["got"] for i in self.order)

    async def _fire(self) -> None:
        """Compute ``af(sum(w.v) + bias)``, broadcast it, then reset all inputs.

        Raises:
            ValueError: if an input's value vector and weight vector differ in length.
        """
        acc = 0.0
        for i in self.order:
            slot = self.inputs[i]
            w, v = slot["weights"], slot["val"]
            if len(w) != len(v):
                raise ValueError(
                    f"Lengths of weights and values must be equal "
                    f"(neuron {self.nid}, input {i!r}: {len(w)} weights, {len(v)} values)"
                )
            acc += sum(wj * vj for wj, vj in zip(w, v))
        total = acc + self.bias
        out = self.af(total)
        for pid in self.outputs:
            await pid.send(("forward", self.nid, [out]))
        for i in self.order:
            self.inputs[i]["got"] = False
            self.inputs[i]["val"] = []
        log.debug("Neuron %s: input_sum=%.3f, output=%.3f", self.nid, total, out)

    # --------------------------------------------------------------- handlers

    async def _on_forward(self, msg) -> None:
        """Record an input value and fire when all inputs are present."""
        _, from_id, data = msg
        slot = self.inputs.get(from_id)
        if slot is None:
            # Not one of our inputs; drop it, as before.
            return
        values = self._as_values(data)
        if slot["recurrent"]:
            # Recurrent values only become visible at the next cycle_start.
            slot["next_val"] = values
        else:
            slot["got"] = True
            slot["val"] = values
        if self._ready():
            await self._fire()

    async def _on_tick(self, msg) -> None:
        """Fire if every input has been received."""
        if self._ready():
            await self._fire()

    async def _on_get_backup(self, msg) -> None:
        """Send copies of the current weights and bias to ``cx_pid``."""
        # Copy the lists: weight_perturb mutates them in place, which would
        # otherwise silently alter the backup the receiver holds.
        # NOTE(review): the recurrent flag is not included, so this list is not
        # directly re-feedable as ``input_idps`` (which expects triples).
        idps = [(i, list(self.inputs[i]["weights"])) for i in self.order]
        idps.append(("bias", [self.bias]))
        await self.cx_pid.send(("backup_from_neuron", self.nid, idps))

    async def _on_weight_backup(self, msg) -> None:
        """Snapshot the current input weights and bias."""
        log.debug("Neuron %s: backing up weights", self.nid)
        self._backup_inputs = {k: {"weights": v["weights"][:]} for k, v in self.inputs.items()}
        self._backup_bias = self.bias

    async def _on_weight_restore(self, msg) -> None:
        """Restore the last snapshot, if any."""
        if self._backup_inputs is None:
            return
        for k in self.inputs:
            self.inputs[k]["weights"] = self._backup_inputs[k]["weights"][:]
        self.bias = self._backup_bias

    async def _on_weight_perturb(self, msg) -> None:
        """Perturb each weight and the bias with probability ``1/sqrt(N + 1)``.

        N is the number of input weights; the +1 accounts for the bias.
        """
        n_weights = sum(len(self.inputs[i]["weights"]) for i in self.order)
        log.debug("Neuron %s: perturbing %d weights", self.nid, n_weights)
        mp = 1 / math.sqrt(n_weights + 1)
        delta_mag = self.PERTURB_DELTA_MAG
        sat_lim = self.WEIGHT_SAT_LIMIT
        for i in self.order:
            ws = self.inputs[i]["weights"]
            for j, w in enumerate(ws):
                if random.random() < mp:
                    ws[j] = _sat(w + (random.random() - 0.5) * delta_mag, -sat_lim, sat_lim)
        if random.random() < mp:
            self.bias = _sat(self.bias + (random.random() - 0.5) * delta_mag, -sat_lim, sat_lim)

    async def _on_cycle_start(self, msg) -> None:
        """Start a cycle: promote buffered recurrent values, clear the rest."""
        for i in self.order:
            slot = self.inputs[i]
            if slot["recurrent"]:
                # With no buffered value yet (e.g. the first cycle), feed zeros.
                slot["val"] = slot["next_val"] or [0.0] * max(1, len(slot["weights"]))
                slot["got"] = True
                slot["next_val"] = []
            else:
                slot["got"] = False
                slot["val"] = []


def _sat(val, lo, hi):
    """Clamp *val* to the closed interval ``[lo, hi]``."""
    return lo if val < lo else (hi if val > hi else val)