#include <errno.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fcntl.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/time.h>
#include <termios.h>
#include <unistd.h>

#include <serial.h>
#include <sprog.h>

int serial_setbaud_termios(struct termios *t, int baud);
int serial_select(int nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds, struct timeval *timeout);

/*
 * Baud rates understood by serial_setbaud_termios(), each paired with the
 * <termios.h> speed constant handed to cfsetispeed()/cfsetospeed().
 * Keep the entries in ascending order of rate: the lookup stops at the first
 * entry that is at least as fast as the request and then steps back one.
 * The list ends with a sentinel whose rate is not positive.
 */
#define BAUD_ENTRY(rate) { .baud = rate, .code = B##rate }

const struct baud_code baud_codes[] = {
  BAUD_ENTRY(50),
  BAUD_ENTRY(75),
  BAUD_ENTRY(110),
  BAUD_ENTRY(134),
  BAUD_ENTRY(150),
  BAUD_ENTRY(200),
  BAUD_ENTRY(300),
  BAUD_ENTRY(600),
  BAUD_ENTRY(1200),
  BAUD_ENTRY(1800),
  BAUD_ENTRY(2400),
  BAUD_ENTRY(4800),
  BAUD_ENTRY(9600),
  BAUD_ENTRY(19200),
  BAUD_ENTRY(38400),
  BAUD_ENTRY(57600),
  BAUD_ENTRY(115200),
  BAUD_ENTRY(230400),
  { .baud = -1, .code = 0 }   /* end-of-table sentinel */
};

#undef BAUD_ENTRY

/*
 * Open the serial device at `path` and configure it for 8 data bits,
 * 1 stop bit, no parity, canonical input and raw output at `baud`
 * (rounded by serial_setbaud_termios() to a supported rate).
 *
 * On success fills port->fd and port->f (a stdio stream over the same
 * descriptor) and returns the descriptor. If the device cannot be opened
 * the program exits with status 1. Failures to read/apply the terminal
 * settings or to create the stream are reported via sprog_error() but are
 * not fatal, matching the previous best-effort behaviour.
 */
int serial_open(struct serial_device *port, const char *path, int baud) {
  struct termios attr;
  int fd;

  /* O_NOCTTY: the port must never become this process's controlling tty. */
  fd = open(path, O_RDWR | O_NOCTTY);
  if(fd<0) {
    sprog_error("Unable to open serial port %s: %s\n", path, strerror(errno));
    exit(1);
  }

  /* Start from a zeroed structure so a failed tcgetattr() (e.g. `path` is
   * not a terminal) never leaves fields uninitialised for the code below. */
  memset(&attr, 0, sizeof attr);
  if(tcgetattr(fd, &attr) < 0)
    sprog_error("Unable to read settings of serial port %s: %s\n", path, strerror(errno));

  /* 8 data bits, 1 stop bit, no parity */
  attr.c_iflag = IGNBRK;
#ifdef OLCUC
  attr.c_oflag &= ~(OPOST | OLCUC | ONOCR | ONLRET | ONLCR);
#else
  /* OLCUC is not defined on every platform */
  attr.c_oflag &= ~(OPOST | ONOCR | ONLRET | ONLCR);
#endif
  attr.c_cflag = CS8 | CREAD | CLOCAL;
  attr.c_lflag = ICANON;
  serial_setbaud_termios(&attr, baud);
  if(tcsetattr(fd, TCSANOW, &attr) < 0)
    sprog_error("Unable to configure serial port %s: %s\n", path, strerror(errno));

  port->fd = fd;
  port->f = fdopen(fd, "r+");
  if(port->f == NULL)
    sprog_error("Unable to create a stream for serial port %s: %s\n", path, strerror(errno));
  return fd;
}

void serial_close(struct serial_device *port) {
  close(port->fd);
}

/*
 * Change the baud rate of an already opened port, keeping all other
 * terminal settings.
 *
 * Returns the rate actually programmed (see serial_setbaud_termios() for how
 * unsupported rates are mapped), or -1 if the current settings could not be
 * read or the new ones could not be applied (an error is reported).
 */
int serial_setbaud(struct serial_device *port, int baud) {
  struct termios attr;

  /* Bail out rather than modify and write back an uninitialised struct. */
  if(tcgetattr(port->fd, &attr) < 0) {
    sprog_error("Unable to read serial port settings: %s\n", strerror(errno));
    return -1;
  }
  baud = serial_setbaud_termios(&attr, baud);
  if(tcsetattr(port->fd, TCSANOW, &attr) < 0) {
    sprog_error("Unable to set serial port baud rate: %s\n", strerror(errno));
    return -1;
  }
  return baud;
}

int serial_setbaud_termios(struct termios *t, int baud) {
  int i;
  int baudcode;
  
  /* find the corresponding baud code */
  
  for(i=0; baud_codes[i].baud>0; i++) {
    if(baud_codes[i].baud == baud)
      break;
    if(baud_codes[i].baud > baud) {
      if(i>0) i--;
      break;
    }
  }
  
  /* if the selected baud rate is greater than any of available rates, set the highest one */
  if(baud_codes[i].baud<=0) i--;
  
  if(baud_codes[i].baud != baud)
    sprog_error("Unsupported baud rate %d, using the nearest value %d\n", baud, baud_codes[i].baud);
  
  baudcode = baud_codes[i].code;
  cfsetispeed(t, baudcode);
  cfsetospeed(t, baudcode);
  return baud_codes[i].baud;
}

/*
 * Raise (state != 0) or drop (state == 0) a modem control line.
 * `line` selects DTR when equal to SERIAL_DTR; any other value selects RTS.
 * Failures are reported via sprog_error(); the other lines are left alone.
 */
void serial_setline(struct serial_device *port, int line, int state) {
  int mask = (line == SERIAL_DTR) ? TIOCM_DTR : TIOCM_RTS;

#if defined(TIOCMBIS) && defined(TIOCMBIC)
  /* Set/clear only the selected bit in one call: no read-modify-write of
   * the whole modem-line word, so nothing else can be clobbered. */
  if(ioctl(port->fd, state ? TIOCMBIS : TIOCMBIC, &mask) < 0)
    sprog_error("Unable to change serial control line: %s\n", strerror(errno));
#else
  int bits;

  /* Never write back `bits` unless it was actually read. */
  if(ioctl(port->fd, TIOCMGET, &bits) < 0) {
    sprog_error("Unable to read serial control lines: %s\n", strerror(errno));
    return;
  }
  if(state)
    bits |= mask;
  else
    bits &= ~mask;
  if(ioctl(port->fd, TIOCMSET, &bits) < 0)
    sprog_error("Unable to change serial control line: %s\n", strerror(errno));
#endif
}

/*
 * Write the NUL-terminated string `text` (without the terminator) to the
 * port, looping over partial writes until every byte has been sent.
 * A write interrupted by a signal (EINTR) is retried; any other write error
 * is reported via sprog_error() and the remaining bytes are dropped.
 */
void serial_write(struct serial_device *port, const char *text) {
  size_t total = strlen(text);
  size_t done = 0;

  while(done < total) {
    ssize_t n = write(port->fd, text + done, total - done);
    if(n < 0) {
      if(errno == EINTR)
        continue;   /* interrupted before anything was written: retry */
      sprog_error("Error while writing to the serial port: %s\n", strerror(errno));
      break;
    }
    done += (size_t)n;
  }
}

int serial_read(struct serial_device *port, char *buf, int len, int timeout) {
  int bytes;
  fd_set readset;
  struct timeval tval;
  FD_ZERO(&readset);
  FD_SET(port->fd, &readset);
  
  tval.tv_sec = timeout/1000;
  tval.tv_usec = (timeout % 1000) * 1000;
  
  bytes = 0;
  while((serial_select(port->fd+1, &readset, NULL, NULL, &tval)>0) && (len-bytes)>0) {
    if(FD_ISSET(port->fd, &readset))
      bytes += read(port->fd, &buf[bytes], len-bytes);
  }
  return bytes;
}

/*
 * select() wrapper with Linux timeout semantics on every platform.
 *
 * Linux's select() subtracts the time spent waiting from *timeout; other UNIX
 * implementations need not. Elsewhere this is emulated by measuring the
 * elapsed time with gettimeofday() and storing the remaining time (clamped
 * at zero) back into *timeout, so callers can loop on the remaining budget.
 * A NULL timeout (wait indefinitely) is passed through unchanged.
 * Returns select()'s result; errno reflects select(), not the time queries.
 */
int serial_select(int nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds, struct timeval *timeout) {
  int res;

#ifndef __linux__
  struct timeval budget;
  struct timeval start_time;
  struct timeval cur_time;
  long long remaining;
  int saved_errno;

  if(timeout == NULL)
    return select(nfds, readfds, writefds, exceptfds, NULL);

  /* Work from a private copy: POSIX allows select() to modify *timeout,
   * and subtracting from an already-updated value would count twice. */
  budget = *timeout;
  gettimeofday(&start_time, NULL);
  res = select(nfds, readfds, writefds, exceptfds, timeout);
  saved_errno = errno;   /* gettimeofday() may overwrite errno */
  gettimeofday(&cur_time, NULL);

  /* 64-bit arithmetic: seconds * 1000000 overflows a 32-bit long */
  remaining = ((long long)budget.tv_sec - (cur_time.tv_sec - start_time.tv_sec)) * 1000000LL
            + ((long long)budget.tv_usec - (cur_time.tv_usec - start_time.tv_usec));
  if(remaining < 0)
    remaining = 0;

  timeout->tv_sec = remaining / 1000000;
  timeout->tv_usec = remaining % 1000000;
  errno = saved_errno;
#else
  res = select(nfds, readfds, writefds, exceptfds, timeout);
#endif

  return res;
}