# coding=utf-8

import os, bisect
import transitfeed
import pytz

from flask import url_for
from datetime import datetime, timedelta
from dateutil.tz import tzlocal
from collections import OrderedDict

# Basic LRU cache implemented using OrderedDict
# This is mainly helpful for getting the full sorted schedule of a stop (time interpolation and sorting are expensive)
class GtfsCache(object):
	"""Small LRU cache built on an OrderedDict.

	Every entry carries a validTag (comparable to an HTTP etag): a lookup with a
	different tag is treated as a miss, which lets callers invalidate entries
	(e.g. per service date) without deleting them. Once maxItems entries are
	stored, the least recently used one is evicted. maxItems <= 0 disables caching.
	"""

	def __init__(self, maxItems = 400):
		self._store = OrderedDict()
		self._maxItems = maxItems

	# validTag = value to check if stored value should still be used (comparable to http etag)
	def put(self, key, object, validTag = -1):
		"""Store object under key with validTag and return object unchanged.

		Re-putting an existing key replaces it and marks it most recently used.
		"""
		if self._maxItems <= 0:
			# caching disabled; evicting from the empty store would raise KeyError
			return object

		if key in self._store:
			del self._store[key]
		elif len(self._store) >= self._maxItems:
			# evict the least recently used entry (front of the OrderedDict)
			self._store.popitem(last = False)

		self._store[key] = {'value': object, 'validTag': validTag}
		return object

	def get(self, key, validTag = -1, defaultValue = None):
		"""Return the value cached under key, or defaultValue on a miss.

		Any stored entry is marked most recently used. An entry stored with a
		different validTag counts as a miss; it is kept until put() overwrites it.
		"""
		container = self._store.pop(key, None)
		if container is None:
			return defaultValue

		# re-insert at the back to mark the entry as most recently used
		self._store[key] = container
		if container['validTag'] == validTag:
			return container['value']

		return defaultValue

	def cachedFunc(self, key, fn, validTag = -1):
		"""Return the value cached for key/validTag, computing it with fn() on a miss.

		A None result from fn() is returned but not cached.
		"""
		cached = self.get(key, validTag)
		if cached is not None:
			return cached

		value = fn()
		if value is not None:
			self.put(key, value, validTag)

		return value

# handler for working with GTFS data directly, no need to use the Föli API as a go-between
# Using Google's transitfeed library (doesn't support python 3 for the time being)
# 
# Föli terms:
# 	trip = vuoro
# 	route = linja
class GtfsHandler(object):
	"""Serve stop, trip and route data straight from a GTFS feed.

	The feed is parsed with transitfeed (see load()). A stop's sorted,
	time-interpolated daily schedule is expensive to build, so it is kept in a
	GtfsCache keyed by stop id and invalidated when the service date changes.

	Föli terms: trip = vuoro, route = linja.
	"""

	# transitfeed schedule set by load(); None until a feed is loaded (see isValid())
	schedule = None
	# optional vehicle locator service, set with setLocator()
	locator = None

	_cache = None

	def __init__(self, gtfsName = None, dir = None, defer = False):
		"""Create the handler and load the feed right away unless defer is True.

		gtfsName and dir are only needed when defer is False; with defer = True
		call load() later.
		"""
		self._cache = GtfsCache()
		if not defer:
			self.load(gtfsName, dir)

	def load(self, feedName, dir = None):
		"""Load the GTFS feed feedName from dir (default: the working directory).

		feedName may name a directory or, without the extension, a .zip file.
		Stops that no trip visits are removed from the loaded schedule.

		Raises Exception when neither the directory nor the .zip file exists.
		"""
		if not dir:
			dir = os.getcwd()

		feedPath = os.path.normpath(os.path.join(dir, feedName))
		if not os.path.isdir(feedPath):
			feedPath += '.zip'
			if not os.path.isfile(feedPath):
				raise Exception('Unable to locate GTFS data.')

		loader = transitfeed.Loader(feedPath)
		self.schedule = loader.Load()

		# filter unused stops; iterate over a snapshot since entries get deleted
		for stopId, stop in list(self.schedule.stops.items()):
			if not stop.GetTrips(self.schedule):
				del self.schedule.stops[stopId]

	def _stopToDict(self, stop):
		"""Return the public dictionary describing stop (cached per stop id).

		Builds an absolute 'stop_url' with flask.url_for('view', ...), so it needs
		an active Flask application context.
		"""
		return self._cache.cachedFunc(str(stop.stop_id) + '-info', lambda: {
			'stop_id': stop.stop_id,
			'stop_code': stop.stop_code,
			'stop_name': stop.stop_name,
			'stop_lat': float(stop.stop_lat),
			'stop_lon': float(stop.stop_lon),
			'stop_url':  url_for('view', stopId = stop.stop_id, _external = True),
			'stop_timezone': stop.stop_timezone
		})

	def _getStopSchedule(self, stop, now):
		"""Return [(secs, (trip, stopTimeIndex)), ...] for trips active on now's date.

		secs is the (interpolated) stop time in seconds since midnight. The list is
		sorted by secs and cached per stop, tagged with now's date.
		"""
		def build():
			active = [(secs, tripIndex)
				for secs, tripIndex, _isTimepoint in stop.GetStopTimeTrips(self.schedule)
				if self._inactivePredicate(tripIndex[0], now)]
			# sort on the time only: ordering trip objects on ties is meaningless
			# (and unsupported on python 3); the stable sort keeps feed order for ties
			active.sort(key = lambda item: item[0])
			return active

		return self._cache.cachedFunc(str(stop.stop_id) + '-trips', build, now.strftime("%Y%m%d"))

	def _getHeadsign(self, trip, index, stopTimes = None):
		"""Return the headsign displayed at the index'th stop time of trip.

		The nearest non-empty stop_headsign at or before index wins; otherwise the
		trip_headsign. stopTimes may pass trip.GetStopTimes() when the caller
		already has it, to avoid fetching it again.
		"""
		if stopTimes is None:
			stopTimes = trip.GetStopTimes()

		for st in stopTimes[index::-1]:
			if st.stop_headsign:
				return st.stop_headsign

		return trip.trip_headsign

	def _getTripSlug(self, stop, trip, stopTrips = None):
		"""Return '<stop_id>.<n>' where n is trip's first position in stop.GetTrips().

		getStopTrip() resolves the slug back into the trip. stopTrips may pass a
		precomputed stop.GetTrips() result so loops don't fetch it once per trip.
		"""
		if stopTrips is None:
			stopTrips = stop.GetTrips()

		return '{0}.{1}'.format(stop.stop_id, stopTrips.index(trip))

	def _getRouteName(self, route):
		"""Return 'short - long', or whichever of the two route names is set ('' if neither)."""
		names = [name for name in (route.route_short_name, route.route_long_name) if name]
		return " - ".join(names)

	def _getCurrentTime(self, stop):
		"""Return the current aware datetime, in the stop's time zone when it has one."""
		now = datetime.now(tzlocal())
		if stop.stop_timezone:
			now = now.astimezone(pytz.timezone(stop.stop_timezone))

		return now

	def _chooseStopTime(self, st, preferArrival, defaultTime = None):
		"""Return (isArrival, secs) for stop time st.

		Uses the arrival (preferArrival True) or departure seconds. If that one is
		missing, the flag is flipped and defaultTime is used when given, else
		st.GetTimeSecs().
		"""
		isArrival = preferArrival
		stSecs = st.arrival_secs if preferArrival else st.departure_secs
		if stSecs is None:
			isArrival = not isArrival
			stSecs = defaultTime if defaultTime is not None else st.GetTimeSecs()

		return (isArrival, stSecs)

	def _routeToDict(self, route):
		"""Return the public dictionary describing a transitfeed route."""
		return {
			'route_id': route.route_id,
			'route_name': self._getRouteName(route),
			'short_name': route.route_short_name
		}

	def _tripToDict(self, trip, midnight, currentTime, arrivalTimes):
		"""Return the route dictionary of trip with every stop time listed under 'stops'.

		midnight: datetime the stop time seconds are added to.
		currentTime: datetime 'time_relative' (seconds) is measured from.
		arrivalTimes: prefer arrival (True) or departure (False) times.
		"""
		stopTimes = trip.GetStopTimes()
		# interpolated times, fetched only if some stop time has no time at all
		interpolated = None

		stops = []
		for index, st in enumerate(stopTimes):
			defaultTime = None
			if st.GetTimeSecs() is None:
				# untimed (non-timepoint) stop: without a fallback the time below
				# would be timedelta(seconds = None), which raises TypeError
				if interpolated is None:
					interpolated = trip.GetTimeInterpolatedStops()
				defaultTime = interpolated[index][0]

			isArrival, stSecs = self._chooseStopTime(st, arrivalTimes, defaultTime)
			stopTime = midnight + timedelta(seconds = stSecs)

			result = {
				'time': stopTime,
				'time_relative': (stopTime - currentTime).total_seconds(),
				'is_arrival': isArrival,
				'headsign': self._getHeadsign(trip, index, stopTimes),
				'stop': {
					'stop_id': st.stop.stop_id,
					'stop_code': st.stop.stop_code,
					'stop_name': st.stop.stop_name
				}
			}

			if trip.trip_short_name:
				result['trip_name'] = trip.trip_short_name

			stops.append(result)

		result = self._routeToDict(self.schedule.GetRoute(trip.route_id))
		result['stops'] = stops
		return result

	def _tripToRoute(self, trip, shortOnly):
		"""Return the route of trip: its short name (or route_id) if shortOnly, else its dictionary."""
		route = self.schedule.GetRoute(trip.route_id)
		if shortOnly:
			return route.route_short_name if route.route_short_name else route.route_id

		return self._routeToDict(route)

	def isValid(self):
		"""Return True once a schedule has been loaded."""
		return self.schedule is not None

	def getStop(self, stopId):
		"""Return the stop stopId as a dictionary, or {} if it is unknown."""
		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return {}

		return self._stopToDict(stop)

	def _inactivePredicate(self, trip, now):
		"""Return True if trip's service period is active on now's date (always True when now is falsy).

		Despite the name this is the "is active" test used to keep trips.
		"""
		return not now or self.schedule.GetServicePeriod(trip.service_id).IsActiveOn(now.strftime("%Y%m%d"), date_object = now)

	def getDaysRoutes(self, stopId, now = None, asList = False):
		"""Return the routes serving stopId on now's date (default: today at the stop). Unordered.

		asList = True gives a set of route short names (route_id when a route has
		no short name); otherwise a list of route dictionaries, one per route_id.
		Returns [] for an unknown stop.
		"""
		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return []

		if not now:
			now = self._getCurrentTime(stop)

		routes = (self._tripToRoute(trip, asList) for trip in stop.GetTrips() if self._inactivePredicate(trip, now))
		if not asList:
			return list({r['route_id']: r for r in routes}.values())

		return set(routes)

	def countDaysTrips(self, stopId, now = None):
		"""Return how many stop times stopId has on now's date (0 for an unknown stop)."""
		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return 0

		if not now:
			now = self._getCurrentTime(stop)

		# this function is useless without a call to getDaysTrips() in most cases so let's just cache the stop schedule early if we can
		return len(self._getStopSchedule(stop, now))

	def getDaysTrips(self, stopId, offset = 0, limit = -1, now = None, fullTrips = False):
		"""Return the whole day's departures at stopId, starting from midnight of now.

		Same result format as getNextTrips(); limit <= 0 means no limit.
		"""
		if not now:
			now = self.getCurrentTimeStop(stopId)

		# todo: things could be said about performance here as well, but I'd rather avoid duplicating the code
		return self.getNextTrips(
			stopId = stopId,
			now = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0),
			offset = offset,
			limit = limit,
			arrivalTimes = False,
			fullTrips = fullTrips
		)

	def getNextTrips(self, stopId, now = None, offset = 0, limit = 5, arrivalTimes = True, fullTrips = False):
		"""Return the stop times at stopId from now on, from the static schedule.

		offset skips results; limit <= 0 means no limit. Each result holds 'time',
		'time_relative' (seconds from the current time), 'is_arrival', 'headsign',
		'trip_slug', 'route' (route dict, or the full trip dict when fullTrips) and
		'trip_name' when the trip has a short name. Returns [] for an unknown stop.

		Föli exposes a stop monitoring (SM) endpoint that would be more accurate,
		but we can't control what it returns and can't hold SM responses in memory
		like vehicle monitoring; merging the two has to happen client side (JS).
		"""
		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return []

		if not now:
			now = self._getCurrentTime(stop)

		# todo: performance... caching it in memory is a decent stop gap since we can trade the memory for speed here
		trips = self._getStopSchedule(stop, now)

		# time is the relative time of day (seconds since midnight)
		midnight = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0)
		timeSecs = (now - midnight).total_seconds()
		if timeSecs != 0:
			# drop everything before now; bisect on the times alone so trip objects are never compared
			times = [secs for secs, _tripIndex in trips]
			trips = trips[bisect.bisect_left(times, timeSecs):]

		if offset > 0:
			trips = trips[offset:]

		if limit > 0:
			trips = trips[:limit]

		results = []
		tripCache = {}
		# fetched once for all slugs instead of once per result
		stopTrips = stop.GetTrips() if trips else None
		currentTime = self._getCurrentTime(stop).replace(microsecond = 0)
		for secs, (trip, stIndex) in trips:
			stopTimes = trip.GetStopTimes()
			st = stopTimes[stIndex]
			isArrival, stSecs = self._chooseStopTime(st, arrivalTimes, secs)
			stopTime = midnight + timedelta(seconds = stSecs)

			result = {
				'time': stopTime,
				'time_relative': (stopTime - currentTime).total_seconds(),
				'is_arrival': isArrival,
				'headsign': self._getHeadsign(trip, stIndex, stopTimes),
				'trip_slug': self._getTripSlug(stop, trip, stopTrips)
			}

			if fullTrips:
				# a trip may visit the stop more than once; build its dictionary only once
				routeTrip = tripCache.get(trip.trip_id)
				if not routeTrip:
					routeTrip = tripCache[trip.trip_id] = self._tripToDict(trip, midnight, currentTime, arrivalTimes)

				result['route'] = routeTrip
			else:
				result['route'] = self._routeToDict(self.schedule.GetRoute(trip.route_id))

			if trip.trip_short_name:
				result['trip_name'] = trip.trip_short_name

			results.append(result)

		return results

	def getStopTrip(self, stopId, index, now = None, arrivalTimes = False):
		"""Return the full trip dictionary for a trip slug '<stopId>.<index>'.

		index is the position in an unfiltered stop.GetTrips() (see _getTripSlug()).
		Returns [] if the stop is unknown, the index is out of range or the trip
		does not run on now's date.
		"""
		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return []

		stopTrips = stop.GetTrips()
		# slugs only ever hold valid positions, so anything else is stale or bogus
		if not 0 <= index < len(stopTrips):
			return []

		trip = stopTrips[index]

		if not now:
			now = self._getCurrentTime(stop)

		if not self._inactivePredicate(trip, now):
			return []

		return self._tripToDict(
			trip = trip,
			midnight = now.replace(hour = 0, minute = 0, second = 0, microsecond = 0),
			currentTime = self._getCurrentTime(stop).replace(microsecond = 0),
			arrivalTimes = arrivalTimes
		)

	def getCurrentTimeStop(self, stopId):
		"""Return the current time at stopId (local time for an unknown stop)."""
		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return datetime.now(tzlocal())

		return self._getCurrentTime(stop)

	def getStopTimeZone(self, stopId):
		"""Return the stop's time zone name, or the local zone abbreviation as a fallback."""
		stop = self.schedule.stops.get(stopId, None)
		if not stop or not stop.stop_timezone:
			# tzinfo.tzname() requires the datetime it is asked about (DST decides the name)
			return tzlocal().tzname(datetime.now())

		return stop.stop_timezone

	def getShape(self, lineref):
		"""Return the shape points of the route whose route_short_name is lineref.

		Uses the shape of the first trip on that route that has a shape_id.
		Raises KeyError if no such route, trip or shape exists.
		"""
		# routes.txt -> route id, trips.txt -> shape id, shapes.txt -> shape
		routeId = None
		for key, route in self.schedule.routes.items():
			if route.route_short_name == lineref:
				routeId = key
				break

		if routeId is None:
			raise KeyError('No route with short name {0!r}'.format(lineref))

		shapeId = None
		for trip in self.schedule.trips.values():
			if trip.route_id == routeId and trip.shape_id:
				shapeId = trip.shape_id
				break

		if shapeId is None:
			raise KeyError('No shape for route {0!r}'.format(routeId))

		return self.schedule.GetShape(shapeId).points

	def setLocator(self, locatorService):
		"""Set the vehicle locator service used by the locate* methods."""
		self.locator = locatorService

	def locateRoutedVehicles(self, stopId):
		"""Return the locator's vehicle locations for the routes serving stopId today.

		Returns (datetime.utcnow(), {}) without a locator or when no routes serve
		the stop; otherwise whatever locator.getLocations() returns.
		"""
		if not self.locator:
			return (datetime.utcnow(), {})

		routes = self.getDaysRoutes(stopId, asList = True)
		if not routes:
			return (datetime.utcnow(), {})

		return self.locator.getLocations(routes)

	def locateNearbyVehicles(self, stopId, maxDist):
		"""Return the locator's vehicles within maxDist of stopId.

		Returns (datetime.utcnow(), {}) without a locator or for an unknown stop;
		otherwise whatever locator.getNearbyVehicles() returns.
		"""
		if not self.locator:
			return (datetime.utcnow(), {})

		stop = self.schedule.stops.get(stopId, None)
		if not stop:
			return (datetime.utcnow(), {})

		return self.locator.getNearbyVehicles(float(stop.stop_lat), float(stop.stop_lon), maxDist)
