You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

206 lines
6.8 KiB

#!/usr/bin/env python3
import logging
import sys
import serial
import threading
from threading import Thread
import time
import binascii
from base64 import b64decode, b64encode
from PyCRC.CRC32 import CRC32
import struct
import messages_pb2 as messages
from google.protobuf.message import Message
log = logging.getLogger("ugv")


class UGVComms:
    """Framed protobuf link to the UGV over a serial port.

    Frame format (both directions, as built by ``write_base64`` and parsed by
    ``read_message``): the serialized protobuf payload, followed by its CRC32
    packed as a little-endian unsigned 32-bit integer, base64-encoded, and
    terminated by a single newline byte.

    Commands are sent as ``messages.GroundMessage``; incoming frames are
    parsed as ``messages.UGV_Message``. A background RX thread (see ``start``
    and ``stop``) reads incoming frames, records command acks and the latest
    status, and passes each decoded message to the optional
    ``on_msg_received`` callback.
    """

    # Number of times write_command re-sends an unacknowledged command.
    MAX_WRITE_RETRY = 5
    # Seconds write_command waits for an ack before re-sending.
    RETRY_TIME = 1.5

    def __init__(self, serial_port: serial.Serial, on_msg_received=None):
        """Wrap an already-configured serial port.

        :param serial_port: pyserial port used for all reads and writes.
        :param on_msg_received: optional callable invoked with every
            successfully decoded ``UGV_Message`` by ``process_message``.
        """
        self.ser = serial_port
        self.on_msg_received = on_msg_received
        # Command ids acknowledged by the UGV and not yet consumed by
        # write_command. Guarded by ack_cv.
        # NOTE(review): acks for commands sent with retry=False, and duplicate
        # acks for re-sent commands, are never consumed and stay in this list.
        self.msg_acks = []
        self.ack_cv = threading.Condition()
        self.next_command_id = 1
        self.last_status = None  # most recent ``status`` payload received
        self.last_status_time = None  # time.time() when last_status arrived
        self.rx_thread = None
        self.is_running = False
        self.log_file = None  # text file received messages are appended to

    def write_base64(self, data: bytes):
        """Frame *data* and write it to the serial port.

        Appends the little-endian CRC32 of *data*, base64-encodes the result
        and terminates it with a newline.
        """
        crc = CRC32().calculate(data)
        frame = b64encode(bytes(data) + struct.pack('<L', crc)) + b'\n'
        # Payload and terminator go out in a single write call.
        self.ser.write(frame)

    def write_message(self, msg: Message):
        """Serialize a protobuf message and send it as one frame."""
        log.debug("writing message: %s", msg)
        self.write_base64(msg.SerializeToString())

    def write_command(self, command, retry=True):
        """Send a command to the UGV, optionally blocking until it is acked.

        :param command: either a command type (an ``int``), sent with no
            other fields set, or a ``GroundCommand`` message that is copied
            into the outgoing ``GroundMessage``.
        :param retry: if true, wait for the UGV to acknowledge this command's
            id, re-sending it every ``RETRY_TIME`` seconds up to
            ``MAX_WRITE_RETRY`` times.
        :raises TimeoutError: if ``retry`` is true and no ack arrives within
            ``RETRY_TIME`` seconds of the last re-send.
        """
        cmdid = self.next_command_id
        self.next_command_id += 1
        gmsg = messages.GroundMessage()
        if isinstance(command, int):
            gmsg.command.type = command
        else:
            gmsg.command.CopyFrom(command)
        gmsg.command.id = cmdid
        self.write_message(gmsg)
        # Monotonic clock: retry intervals must not jump with wall-clock changes.
        last_write_time = time.monotonic()
        if not retry:
            return
        retries_left = self.MAX_WRITE_RETRY
        with self.ack_cv:
            while True:
                if cmdid in self.msg_acks:
                    self.msg_acks.remove(cmdid)
                    log.debug("received ack for command")
                    return
                # Seconds left before the latest send is treated as lost.
                remaining = self.RETRY_TIME - (time.monotonic() - last_write_time)
                if remaining <= 0:
                    if retries_left <= 0:
                        break
                    log.warning("retry writing command")
                    self.write_message(gmsg)
                    last_write_time = time.monotonic()
                    retries_left -= 1
                    remaining = self.RETRY_TIME
                # Woken early by process_message when any ack arrives; the
                # loop then re-checks for this command's id.
                self.ack_cv.wait(timeout=remaining)
        raise TimeoutError("Timeout waiting for command ack")

    def read_message(self):
        """Read and decode one ``UGV_Message`` frame from the serial port.

        :returns: the parsed message, or ``None`` if nothing arrived before
            the port's read timeout or the frame was malformed (invalid
            base64, shorter than the 4-byte CRC, or CRC mismatch).

        ``ParseFromString`` errors are not caught here and propagate to the
        caller.
        """
        # Positional argument: pyserial 3.5 renamed this parameter from
        # ``terminator`` to ``expected``.
        data = self.ser.read_until(b'\n')
        if not data:
            return None
        # Strict base64 validation rejects non-alphabet bytes, so the line
        # terminator returned by read_until must be removed first.
        payload = data.rstrip(b'\r\n')
        try:
            decoded = b64decode(payload, validate=True)
        except binascii.Error:
            log.warning("read bad data: %s", data)
            # NOTE(review): flush() waits for pending output; it does not
            # discard buffered input. Resync happens at the next newline.
            self.ser.flush()
            return None
        if len(decoded) < 4:
            log.warning('Message too short (%s bytes)', len(decoded))
            return None
        body, crc_bytes = decoded[:-4], decoded[-4:]
        msgcrc, = struct.unpack('<L', crc_bytes)
        calccrc = CRC32().calculate(body)
        if msgcrc != calccrc:
            log.warning('Checksum did not match (%s != %s)', msgcrc, calccrc)
            return None
        msg = messages.UGV_Message()
        msg.ParseFromString(body)
        return msg

    def process_message(self, msg: messages.UGV_Message):
        """Handle one received message; ``None`` is ignored.

        Calls ``on_msg_received``, appends the message to ``log_file`` if one
        is open, then records a command ack (waking ``write_command`` waiters)
        or updates ``last_status``/``last_status_time``.
        """
        if msg is None:
            return
        log.debug("received UGV message: %s", msg)
        if self.on_msg_received:
            self.on_msg_received(msg)
        if self.log_file:
            print('[{}] UGV_Message: {}'.format(time.strftime('%Y-%b-%d %H:%M:%S'), msg), file=self.log_file)
        if msg.HasField("command_ack"):
            with self.ack_cv:
                self.msg_acks.append(msg.command_ack)
                # Several write_command calls may be waiting on different ids;
                # wake them all so the right one sees its ack promptly.
                self.ack_cv.notify_all()
        elif msg.HasField("status"):
            self.last_status = msg.status
            self.last_status_time = time.time()
        else:
            log.warning("unknown UGV message: %s", msg)

    def start(self):
        """Start the daemon RX thread.

        :returns: ``True`` if started, ``False`` if already running.
        """
        if self.is_running:
            log.warning("RX thread already running")
            return False
        self.is_running = True
        self.rx_thread = Thread(target=self.__rx_thread_entry, daemon=True)
        self.rx_thread.start()
        log.debug("started RX thread")
        return True

    def stop(self):
        """Stop the RX thread: close the port, then join the thread.

        :returns: ``True`` if it was running, ``False`` otherwise.
        """
        if not self.is_running:
            return False
        self.is_running = False
        # Closing the port unblocks a pending read in the RX thread.
        self.ser.close()
        self.rx_thread.join()
        return True

    def save_logs(self, file):
        """Append every subsequently received message to the text file *file*.

        A log file opened by an earlier call is closed after the new one is
        installed.
        """
        new_file = open(file, mode='a')
        old_file, self.log_file = self.log_file, new_file
        if old_file:
            old_file.close()

    def __rx_thread_entry(self):
        """RX thread body: read and process frames until stopped.

        Exits on any serial error. Other exceptions (e.g. a bad protobuf
        payload or a failing callback) are logged and reading continues.
        Closes and clears ``log_file`` on exit.
        """
        try:
            while self.is_running and self.ser.is_open:
                try:
                    msg = self.read_message()
                    self.process_message(msg)
                except serial.SerialException:
                    if not self.ser.is_open or not self.is_running:
                        # port was probably just closed by stop()
                        return
                    log.error("serial error", exc_info=True)
                    return
                except Exception:
                    log.error("error reading message", exc_info=True)
        finally:
            if self.log_file:
                self.log_file.close()
                # Don't leave a closed file behind for a later start().
                self.log_file = None
def main():
    """Demo driver: connect to the UGV, push a config, keep it on a heading.

    The serial URL comes from the first command-line argument, defaulting to
    ``hwgrep://``. Every 2 seconds, if the last reported state is not
    ``STATE_DRIVE_HEADING`` (or no status has arrived yet), a
    ``CMD_DRIVE_HEADING`` command is sent. Ctrl-C sends ``CMD_DISABLE``.
    The port and RX thread are shut down on exit.
    """
    ser_url = sys.argv[1] if len(sys.argv) >= 2 else "hwgrep://"
    ser = serial.serial_for_url(ser_url, baudrate=9600, parity=serial.PARITY_NONE,
                                stopbits=serial.STOPBITS_ONE, bytesize=serial.EIGHTBITS,
                                timeout=0.5)
    ugv = UGVComms(ser)
    ugv.start()
    # NOTE(review): presumably lets the link settle before the first command.
    time.sleep(0.2)
    try:
        cmd = messages.GroundCommand()
        cmd.type = messages.CMD_SET_TARGET
        cmd.target_location.latitude = 34.068415
        cmd.target_location.longitude = -118.443217
        # Sending the target is currently disabled.
        # NOTE(review): cmd is reused for CMD_SET_CONFIG below, so the
        # target_location set here is still present in it unless the proto
        # declares these fields as a oneof — confirm.
        # ugv.write_command(cmd)
        cmd.type = messages.CMD_SET_CONFIG
        cmd.config.angle_pid.kp = 0.10
        cmd.config.angle_pid.ki = 0  # integral term off (alternative value: .00005)
        cmd.config.angle_pid.kd = 0.4
        cmd.config.angle_pid.max_output = 0.5
        cmd.config.angle_pid.max_i_error = 15.0
        cmd.config.min_target_dist = 10.0
        cmd.config.min_flip_pitch = 90.0
        ugv.write_command(cmd)
        while True:
            # Enum values are ints: compare by value (!=), not identity.
            if ugv.last_status is None or ugv.last_status.state != messages.STATE_DRIVE_HEADING:
                cmd = messages.GroundCommand()
                cmd.type = messages.CMD_DRIVE_HEADING
                cmd.drive_heading.heading = -115.0 - 180
                cmd.drive_heading.power = 0.3
                ugv.write_command(cmd)
            time.sleep(2.0)
    except KeyboardInterrupt:
        # May raise TimeoutError if the UGV never acks; cleanup still runs.
        ugv.write_command(messages.CMD_DISABLE)
        log.info("exiting...")
    finally:
        ugv.ser.flush()
        ugv.ser.close()
        ugv.stop()
if __name__ == "__main__":
    # Console logging with timestamps; the "ugv" logger emits everything
    # from DEBUG upward.
    log_format = '%(asctime)s [%(name)s] %(levelname)s: %(message)s'
    date_format = '%Y-%b-%d %H:%M:%S'
    logging.basicConfig(format=log_format, datefmt=date_format)
    log.setLevel(logging.DEBUG)
    main()