From c7688c6bcadeba0e6131f57a3d7f09a475f14e48 Mon Sep 17 00:00:00 2001
From: Marco D'Alessio <marco@wrecklab.com>
Date: Sun, 18 Oct 2020 10:50:32 +0200
Subject: [PATCH] tmc2130: Add spi daisy chain support

This patch adds the ability to daisy-chain multiple tmc2130 and
tmc5160 drivers.

Signed-off-by: Marco D'Alessio <marco@wrecklab.com>
Signed-off-by: Kevin O'Connor <kevin@koconnor.net>
---
 docs/Config_Reference.md | 12 +++++++
 klippy/extras/bus.py     |  4 +--
 klippy/extras/tmc2130.py | 73 ++++++++++++++++++++++++++++++++--------
 3 files changed, 73 insertions(+), 16 deletions(-)

diff --git a/docs/Config_Reference.md b/docs/Config_Reference.md
index fee1ec204..60f344d4a 100644
--- a/docs/Config_Reference.md
+++ b/docs/Config_Reference.md
@@ -2580,6 +2580,12 @@ cs_pin:
 #spi_software_miso_pin:
 #   See the "common SPI settings" section for a description of the
 #   above parameters.
+#chain_position:
+#chain_length:
+#   These parameters configure an SPI daisy chain. The two parameters
+#   define the stepper position in the chain and the total chain length.
+#   Position 1 corresponds to the stepper that connects to the MOSI signal.
+#   The default is to not use an SPI daisy chain.
 #interpolate: True
 #   If true, enable step interpolation (the driver will internally
 #   step at a rate of 256 micro-steps). The default is True.
@@ -2822,6 +2828,12 @@ cs_pin:
 #spi_software_miso_pin:
 #   See the "common SPI settings" section for a description of the
 #   above parameters.
+#chain_position:
+#chain_length:
+#   These parameters configure an SPI daisy chain. The two parameters
+#   define the stepper position in the chain and the total chain length.
+#   Position 1 corresponds to the stepper that connects to the MOSI signal.
+#   The default is to not use an SPI daisy chain.
 #interpolate: True
 #   If true, enable step interpolation (the driver will internally
 #   step at a rate of 256 micro-steps). The default is True.
diff --git a/klippy/extras/bus.py b/klippy/extras/bus.py
index 7c6ae4401..09259b9f1 100644
--- a/klippy/extras/bus.py
+++ b/klippy/extras/bus.py
@@ -97,11 +97,11 @@ class MCU_SPI:
 
 # Helper to setup an spi bus from settings in a config section
 def MCU_SPI_from_config(config, mode, pin_option="cs_pin",
-                        default_speed=100000):
+                        default_speed=100000, share_type=None):
     # Determine pin from config
     ppins = config.get_printer().lookup_object("pins")
     cs_pin = config.get(pin_option)
-    cs_pin_params = ppins.lookup_pin(cs_pin)
+    cs_pin_params = ppins.lookup_pin(cs_pin, share_type=share_type)
     pin = cs_pin_params['pin']
     if pin == 'None':
         ppins.reset_pin_sharing(cs_pin_params)
diff --git a/klippy/extras/tmc2130.py b/klippy/extras/tmc2130.py
index 891de8580..f300afbcb 100644
--- a/klippy/extras/tmc2130.py
+++ b/klippy/extras/tmc2130.py
@@ -170,12 +170,66 @@ class TMCCurrentHelper:
 # TMC2130 SPI
 ######################################################################
 
+class MCU_TMC_SPI_chain:
+    def __init__(self, config, chain_len=1):
+        self.printer = config.get_printer()
+        self.chain_len = chain_len
+        self.mutex = self.printer.get_reactor().mutex()
+        share = None
+        if chain_len > 1:
+            share = "tmc_spi_cs"
+        self.spi = bus.MCU_SPI_from_config(config, 3, default_speed=4000000,
+                                           share_type=share)
+        self.taken_chain_positions = []
+    def _build_cmd(self, data, chain_pos):
+        return ([0x00] * ((self.chain_len - chain_pos) * 5) +
+                data + [0x00] * ((chain_pos - 1) * 5))
+    def reg_read(self, reg, chain_pos):
+        cmd = self._build_cmd([reg, 0x00, 0x00, 0x00, 0x00], chain_pos)
+        self.spi.spi_send(cmd)
+        if self.printer.get_start_args().get('debugoutput') is not None:
+            return 0
+        params = self.spi.spi_transfer(cmd)
+        pr = bytearray(params['response'])
+        pr = pr[(self.chain_len - chain_pos) * 5 :
+                (self.chain_len - chain_pos + 1) * 5]
+        return (pr[1] << 24) | (pr[2] << 16) | (pr[3] << 8) | pr[4]
+    def reg_write(self, reg, val, chain_pos, print_time=None):
+        minclock = 0
+        if print_time is not None:
+            minclock = self.spi.get_mcu().print_time_to_clock(print_time)
+        data = [(reg | 0x80) & 0xff, (val >> 24) & 0xff, (val >> 16) & 0xff,
+                (val >> 8) & 0xff, val & 0xff]
+        self.spi.spi_send(self._build_cmd(data, chain_pos), minclock)
+
+# Helper to setup an spi daisy chain bus from settings in a config section
+def lookup_tmc_spi_chain(config):
+    chain_len = config.getint('chain_length', None, minval=2)
+    if chain_len is None:
+        # Simple, non daisy chained SPI connection
+        return MCU_TMC_SPI_chain(config, 1), 1
+
+    # Shared SPI bus - lookup existing MCU_TMC_SPI_chain
+    ppins = config.get_printer().lookup_object("pins")
+    cs_pin_params = ppins.lookup_pin(config.get('cs_pin'),
+                                     share_type="tmc_spi_cs")
+    tmc_spi = cs_pin_params.get('class')
+    if tmc_spi is None:
+        tmc_spi = cs_pin_params['class'] = MCU_TMC_SPI_chain(config, chain_len)
+    if chain_len != tmc_spi.chain_len:
+        raise config.error("TMC SPI chain must have same length")
+    chain_pos = config.getint('chain_position', minval=1, maxval=chain_len)
+    if chain_pos in tmc_spi.taken_chain_positions:
+        raise config.error("TMC SPI chain can not have duplicate position")
+    tmc_spi.taken_chain_positions.append(chain_pos)
+    return tmc_spi, chain_pos
+
 # Helper code for working with TMC devices via SPI
 class MCU_TMC_SPI:
     def __init__(self, config, name_to_reg, fields):
         self.printer = config.get_printer()
-        self.mutex = self.printer.get_reactor().mutex()
-        self.spi = bus.MCU_SPI_from_config(config, 3, default_speed=4000000)
+        self.tmc_spi, self.chain_pos = lookup_tmc_spi_chain(config)
+        self.mutex = self.tmc_spi.mutex
         self.name_to_reg = name_to_reg
         self.fields = fields
     def get_fields(self):
@@ -183,21 +237,12 @@ class MCU_TMC_SPI:
     def get_register(self, reg_name):
         reg = self.name_to_reg[reg_name]
         with self.mutex:
-            self.spi.spi_send([reg, 0x00, 0x00, 0x00, 0x00])
-            if self.printer.get_start_args().get('debugoutput') is not None:
-                return 0
-            params = self.spi.spi_transfer([reg, 0x00, 0x00, 0x00, 0x00])
-        pr = bytearray(params['response'])
-        return (pr[1] << 24) | (pr[2] << 16) | (pr[3] << 8) | pr[4]
+            read = self.tmc_spi.reg_read(reg, self.chain_pos)
+        return read
     def set_register(self, reg_name, val, print_time=None):
-        minclock = 0
-        if print_time is not None:
-            minclock = self.spi.get_mcu().print_time_to_clock(print_time)
         reg = self.name_to_reg[reg_name]
-        data = [(reg | 0x80) & 0xff, (val >> 24) & 0xff, (val >> 16) & 0xff,
-                (val >> 8) & 0xff, val & 0xff]
         with self.mutex:
-            self.spi.spi_send(data, minclock)
+            self.tmc_spi.reg_write(reg, val, self.chain_pos, print_time)
 
 
 ######################################################################